1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
#!/usr/bin/env python
'''
Digit recognition adjustment.
Grid search is used to find the best parameters for SVM and KNearest classifiers.
SVM adjustment follows the guidelines given in
http://www.csie.ntu.edu.tw/~cjlin/papers/guide/guide.pdf
Usage:
digits_adjust.py [--model {svm|knearest}]
--model {svm|knearest} - select the classifier (SVM is the default)
'''
# Python 2/3 compatibility
from __future__ import print_function
import sys
PY3 = sys.version_info[0] == 3
if PY3:
xrange = range
import numpy as np
import cv2 as cv
from multiprocessing.pool import ThreadPool
from digits import *
def cross_validate(model_class, params, samples, labels, kfold = 3, pool = None):
n = len(samples)
folds = np.array_split(np.arange(n), kfold)
def f(i):
model = model_class(**params)
test_idx = folds[i]
train_idx = list(folds)
train_idx.pop(i)
train_idx = np.hstack(train_idx)
train_samples, train_labels = samples[train_idx], labels[train_idx]
test_samples, test_labels = samples[test_idx], labels[test_idx]
model.train(train_samples, train_labels)
resp = model.predict(test_samples)
score = (resp != test_labels).mean()
print(".", end='')
return score
if pool is None:
scores = list(map(f, xrange(kfold)))
else:
scores = pool.map(f, xrange(kfold))
return np.mean(scores)
class App(object):
def __init__(self):
self._samples, self._labels = self.preprocess()
def preprocess(self):
digits, labels = load_digits(DIGITS_FN)
shuffle = np.random.permutation(len(digits))
digits, labels = digits[shuffle], labels[shuffle]
digits2 = list(map(deskew, digits))
samples = preprocess_hog(digits2)
return samples, labels
def get_dataset(self):
return self._samples, self._labels
def run_jobs(self, f, jobs):
pool = ThreadPool(processes=cv.getNumberOfCPUs())
ires = pool.imap_unordered(f, jobs)
return ires
def adjust_SVM(self):
Cs = np.logspace(0, 10, 15, base=2)
gammas = np.logspace(-7, 4, 15, base=2)
scores = np.zeros((len(Cs), len(gammas)))
scores[:] = np.nan
print('adjusting SVM (may take a long time) ...')
def f(job):
i, j = job
samples, labels = self.get_dataset()
params = dict(C = Cs[i], gamma=gammas[j])
score = cross_validate(SVM, params, samples, labels)
return i, j, score
ires = self.run_jobs(f, np.ndindex(*scores.shape))
for count, (i, j, score) in enumerate(ires):
scores[i, j] = score
print('%d / %d (best error: %.2f %%, last: %.2f %%)' %
(count+1, scores.size, np.nanmin(scores)*100, score*100))
print(scores)
print('writing score table to "svm_scores.npz"')
np.savez('svm_scores.npz', scores=scores, Cs=Cs, gammas=gammas)
i, j = np.unravel_index(scores.argmin(), scores.shape)
best_params = dict(C = Cs[i], gamma=gammas[j])
print('best params:', best_params)
print('best error: %.2f %%' % (scores.min()*100))
return best_params
def adjust_KNearest(self):
print('adjusting KNearest ...')
def f(k):
samples, labels = self.get_dataset()
err = cross_validate(KNearest, dict(k=k), samples, labels)
return k, err
best_err, best_k = np.inf, -1
for k, err in self.run_jobs(f, xrange(1, 9)):
if err < best_err:
best_err, best_k = err, k
print('k = %d, error: %.2f %%' % (k, err*100))
best_params = dict(k=best_k)
print('best params:', best_params, 'err: %.2f' % (best_err*100))
return best_params
if __name__ == '__main__':
import getopt
import sys
print(__doc__)
args, _ = getopt.getopt(sys.argv[1:], '', ['model='])
args = dict(args)
args.setdefault('--model', 'svm')
args.setdefault('--env', '')
if args['--model'] not in ['svm', 'knearest']:
print('unknown model "%s"' % args['--model'])
sys.exit(1)
t = clock()
app = App()
if args['--model'] == 'knearest':
app.adjust_KNearest()
else:
app.adjust_SVM()
print('work time: %f s' % (clock() - t))
|