// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
This is an example illustrating the use of the multiclass classification tools
from the dlib C++ Library. Specifically, this example will make points from
three classes and show you how to train a multiclass classifier to recognize
these three classes.
The classes are as follows:
- class 1: points very close to the origin
- class 2: points on the circle of radius 10 around the origin
- class 3: points that are on a circle of radius 4 but not around the origin at all
*/
#include <dlib/svm_threaded.h>
#include <iostream>
#include <vector>
#include <dlib/rand.h>
using namespace std;
using namespace dlib;
// Our data will be 2-dimensional data. So declare an appropriate type to contain these points.
typedef matrix<double,2,1> sample_type;

// ----------------------------------------------------------------------------------------

// Fills samples/labels with the 3-class training set described in the header comment.
// Points are appended to the given vectors; labels are 1, 2, or 3.
void generate_data (
    std::vector<sample_type>& samples,
    std::vector<double>& labels
);
/*!
    ensures
        - make some 3 class data as described above.
        - Create 60 points from class 1
        - Create 70 points from class 2
        - Create 80 points from class 3
!*/
// ----------------------------------------------------------------------------------------
int main()
{
try
{
std::vector<sample_type> samples;
std::vector<double> labels;
// First, get our labeled set of training data
generate_data(samples, labels);
cout << "samples.size(): "<< samples.size() << endl;
// The main object in this example program is the one_vs_one_trainer. It is essentially
// a container class for regular binary classifier trainer objects. In particular, it
// uses the any_trainer object to store any kind of trainer object that implements a
// .train(samples,labels) function which returns some kind of learned decision function.
// It uses these binary classifiers to construct a voting multiclass classifier. If
// there are N classes then it trains N*(N-1)/2 binary classifiers, one for each pair of
// labels, which then vote on the label of a sample.
//
// In this example program we will work with a one_vs_one_trainer object which stores any
// kind of trainer that uses our sample_type samples.
typedef one_vs_one_trainer<any_trainer<sample_type> > ovo_trainer;
// Finally, make the one_vs_one_trainer.
ovo_trainer trainer;
// Next, we will make two different binary classification trainer objects. One
// which uses kernel ridge regression and RBF kernels and another which uses a
// support vector machine and polynomial kernels. The particular details don't matter.
// The point of this part of the example is that you can use any kind of trainer object
// with the one_vs_one_trainer.
typedef polynomial_kernel<sample_type> poly_kernel;
typedef radial_basis_kernel<sample_type> rbf_kernel;
// make the binary trainers and set some parameters
krr_trainer<rbf_kernel> rbf_trainer;
svm_nu_trainer<poly_kernel> poly_trainer;
poly_trainer.set_kernel(poly_kernel(0.1, 1, 2));
rbf_trainer.set_kernel(rbf_kernel(0.1));
// Now tell the one_vs_one_trainer that, by default, it should use the rbf_trainer
// to solve the individual binary classification subproblems.
trainer.set_trainer(rbf_trainer);
// We can also get more specific. Here we tell the one_vs_one_trainer to use the
// poly_trainer to solve the class 1 vs class 2 subproblem. All the others will
// still be solved with the rbf_trainer.
trainer.set_trainer(poly_trainer, 1, 2);
// Now let's do 5-fold cross-validation using the one_vs_one_trainer we just setup.
// As an aside, always shuffle the order of the samples before doing cross validation.
// For a discussion of why this is a good idea see the svm_ex.cpp example.
randomize_samples(samples, labels);
cout << "cross validation: \n" << cross_validate_multiclass_trainer(trainer, samples, labels, 5) << endl;
// The output is shown below. It is the confusion matrix which describes the results. Each row
// corresponds to a class of data and each column to a prediction. Reading from top to bottom,
// the rows correspond to the class labels if the labels have been listed in sorted order. So the
// top row corresponds to class 1, the middle row to class 2, and the bottom row to class 3. The
// columns are organized similarly, with the left most column showing how many samples were predicted
// as members of class 1.
//
// So in the results below we can see that, for the class 1 samples, 60 of them were correctly predicted
// to be members of class 1 and 0 were incorrectly classified. Similarly, the other two classes of data
// are perfectly classified.
/*
cross validation:
60 0 0
0 70 0
0 0 80
*/
// Next, if you wanted to obtain the decision rule learned by a one_vs_one_trainer you
// would store it into a one_vs_one_decision_function.
one_vs_one_decision_function<ovo_trainer> df = trainer.train(samples, labels);
cout << "predicted label: "<< df(samples[0]) << ", true label: "<< labels[0] << endl;
cout << "predicted label: "<< df(samples[90]) << ", true label: "<< labels[90] << endl;
// The output is:
/*
predicted label: 2, true label: 2
predicted label: 1, true label: 1
*/
// If you want to save a one_vs_one_decision_function to disk, you can do
// so. However, you must declare what kind of decision functions it contains.
one_vs_one_decision_function<ovo_trainer,
decision_function<poly_kernel>, // This is the output of the poly_trainer
decision_function<rbf_kernel> // This is the output of the rbf_trainer
> df2, df3;
// Put df into df2 and then save df2 to disk. Note that we could have also said
// df2 = trainer.train(samples, labels); But doing it this way avoids retraining.
df2 = df;
serialize("df.dat") << df2;
// load the function back in from disk and store it in df3.
deserialize("df.dat") >> df3;
// Test df3 to see that this worked.
cout << endl;
cout << "predicted label: "<< df3(samples[0]) << ", true label: "<< labels[0] << endl;
cout << "predicted label: "<< df3(samples[90]) << ", true label: "<< labels[90] << endl;
// Test df3 on the samples and labels and print the confusion matrix.
cout << "test deserialized function: \n" << test_multiclass_decision_function(df3, samples, labels) << endl;
// Finally, if you want to get the binary classifiers from inside a multiclass decision
// function you can do it by calling get_binary_decision_functions() like so:
one_vs_one_decision_function<ovo_trainer>::binary_function_table functs;
functs = df.get_binary_decision_functions();
cout << "number of binary decision functions in df: " << functs.size() << endl;
// The functs object is a std::map which maps pairs of labels to binary decision
// functions. So we can access the individual decision functions like so:
decision_function<poly_kernel> df_1_2 = any_cast<decision_function<poly_kernel> >(functs[make_unordered_pair(1,2)]);
decision_function<rbf_kernel> df_1_3 = any_cast<decision_function<rbf_kernel> >(functs[make_unordered_pair(1,3)]);
// df_1_2 contains the binary decision function that votes for class 1 vs. 2.
// Similarly, df_1_3 contains the classifier that votes for 1 vs. 3.
// Note that the multiclass decision function doesn't know what kind of binary
// decision functions it contains. So we have to use any_cast to explicitly cast
// them back into the concrete type. If you make a mistake and try to any_cast a
// binary decision function into the wrong type of function any_cast will throw a
// bad_any_cast exception.
}
catch (std::exception& e)
{
cout << "exception thrown!" << endl;
cout << e.what() << endl;
}
}
// ----------------------------------------------------------------------------------------
void generate_data (
std::vector<sample_type>& samples,
std::vector<double>& labels
)
{
const long num = 50;
sample_type m;
dlib::rand rnd;
// make some samples near the origin
double radius = 0.5;
for (long i = 0; i < num+10; ++i)
{
double sign = 1;
if (rnd.get_random_double() < 0.5)
sign = -1;
m(0) = 2*radius*rnd.get_random_double()-radius;
m(1) = sign*sqrt(radius*radius - m(0)*m(0));
// add this sample to our set of training samples
samples.push_back(m);
labels.push_back(1);
}
// make some samples in a circle around the origin but far away
radius = 10.0;
for (long i = 0; i < num+20; ++i)
{
double sign = 1;
if (rnd.get_random_double() < 0.5)
sign = -1;
m(0) = 2*radius*rnd.get_random_double()-radius;
m(1) = sign*sqrt(radius*radius - m(0)*m(0));
// add this sample to our set of training samples
samples.push_back(m);
labels.push_back(2);
}
// make some samples in a circle around the point (25,25)
radius = 4.0;
for (long i = 0; i < num+30; ++i)
{
double sign = 1;
if (rnd.get_random_double() < 0.5)
sign = -1;
m(0) = 2*radius*rnd.get_random_double()-radius;
m(1) = sign*sqrt(radius*radius - m(0)*m(0));
// translate this point away from the origin
m(0) += 25;
m(1) += 25;
// add this sample to our set of training samples
samples.push_back(m);
labels.push_back(3);
}
}
// ----------------------------------------------------------------------------------------