<!doctype html public "-//w3c//dtd html 4.0 transitional//en">
<html>
<head>
<meta http-equiv="Content-Type" content="text/html; charset=iso-8859-1">
<meta name="GENERATOR" content="Mozilla/4.77 [en] (X11; U; SunOS 5.7 sun4u) [Netscape]">
</head>
<body text="#000000" bgcolor="#FFFFFF" link="#0000EE" vlink="#551A8B" alink="#FF0000">
<tt><font color="#CC6600">const char *help = "\</font></tt>
<br><tt><font color="#CC6600">TorchMLP\n\</font></tt>
<br><tt><font color="#CC6600">\n\</font></tt>
<br><tt><font color="#CC6600">This program will train an MLP with tanh outputs
for\n\</font></tt>
<br><tt><font color="#CC6600">classification and linear outputs for regression\n";</font></tt><tt><font color="#FF9900"></font></tt>
<p><tt><font color="#009900">#include "ConnectedMachine.h"</font></tt>
<br><tt><font color="#009900">#include "Linear.h"</font></tt>
<br><tt><font color="#009900">#include "FileDataSet.h"</font></tt>
<br><tt><font color="#009900">#include "MseCriterion.h"</font></tt>
<br><tt><font color="#009900">#include "Tanh.h"</font></tt>
<br><tt><font color="#009900">#include "MseMeasurer.h"</font></tt>
<br><tt><font color="#009900">#include "ClassMeasurer.h"</font></tt>
<br><tt><font color="#009900">#include "TwoClassFormat.h"</font></tt>
<br><tt><font color="#009900">#include "OneHotClassFormat.h"</font></tt>
<br><tt><font color="#009900">#include "StochasticGradient.h"</font></tt>
<br><tt><font color="#009900">#include "GMTrainer.h"</font></tt>
<br><tt><font color="#009900">#include "CmdLine.h"</font></tt><tt></tt>
<p><tt>int main(int argc, char **argv)</tt>
<br><tt>{</tt>
<br><tt> char *model_file, *test_model_file;</tt>
<br><tt> char *valid_file;</tt>
<br><tt> char *file;</tt><tt></tt>
<p><tt> int n_inputs;</tt>
<br><tt> int n_targets;</tt>
<br><tt> int n_hu;</tt><tt></tt>
<p><tt> int max_load;</tt>
<br><tt> real accuracy;</tt>
<br><tt> real learning_rate;</tt>
<br><tt> real decay;</tt>
<br><tt> int max_iter;</tt>
<br><tt> bool regression;</tt>
<br><tt> int k_fold;</tt>
<br><tt> int the_seed;</tt><tt></tt>
<p><tt> <font color="#FF0000">//=================== The command-line
==========================</font></tt><tt><font color="#FF0000"></font></tt>
<p><tt> <font color="#CC6600">// Construct the command line</font></tt>
<br><tt> CmdLine cmd;</tt><tt></tt>
<p><tt> <font color="#CC6600">// Put the help line at the beginning</font></tt>
<br><tt> cmd.info(help);</tt><tt></tt>
<p><tt> <font color="#CC6600">// Ask for arguments</font></tt>
<br><tt> cmd.addText("\nArguments:");</tt>
<br><tt> cmd.addSCmdArg("file", &file, "the train or test file");</tt>
<br><tt> cmd.addICmdArg("n_inputs", &n_inputs, "input dimension
of the data");</tt>
<br><tt> cmd.addICmdArg("n_targets", &n_targets, "output dimension
of the data");</tt><tt></tt>
<p><tt> <font color="#CC6600">// Propose some options</font></tt>
<br><tt> cmd.addText("\nModel Options:");</tt>
<br><tt> cmd.addICmdOption("-nhu", &n_hu, 25, "number of hidden
units");</tt>
<br><tt> cmd.addBCmdOption("-rm", &regression, false, "regression
mode");</tt><tt></tt>
<p><tt> cmd.addText("\nLearning Options:");</tt>
<br><tt> cmd.addICmdOption("-iter", &max_iter, 25, "max number
of iterations");</tt>
<br><tt> cmd.addRCmdOption("-lr", &learning_rate, 0.01, "learning
rate");</tt>
<br><tt> cmd.addRCmdOption("-e", &accuracy, 0.00001, "end accuracy");</tt>
<br><tt> cmd.addRCmdOption("-lrd", &decay, 0, "learning rate
decay");</tt><tt></tt>
<p><tt> cmd.addText("\nMisc Options:");</tt>
<br><tt> cmd.addICmdOption("-seed", &the_seed, -1, "the random
seed");</tt>
<br><tt> cmd.addICmdOption("-Kfold", &k_fold, -1, "number of
subsets for K-fold cross-validation");</tt>
<br><tt> cmd.addICmdOption("-load", &max_load, -1, "max number
of examples to load");</tt>
<br><tt> cmd.addSCmdOption("-valid", &valid_file, "", "validation
file, if you want it");</tt>
<br><tt> cmd.addSCmdOption("-sm", &model_file, "", "file to save
the model");</tt>
<br><tt> cmd.addSCmdOption("-test", &test_model_file, "", "model
file to test");</tt><tt></tt>
<p><tt> <font color="#CC6600">// Read the command line</font></tt>
<br><tt> cmd.read(argc, argv);</tt><tt></tt>
<p><tt> <font color="#CC6600">// If the user didn't supply a random
seed,</font></tt>
<br><tt><font color="#CC6600"> // generate one at random...</font></tt>
<br><tt> if(the_seed == -1)</tt>
<br><tt> seed();</tt>
<br><tt> else</tt>
<br><tt> manual_seed((long)the_seed);</tt><tt></tt>
<p><tt> <font color="#FF0000">//=================== Create the MLP...
=========================</font></tt>
<br><tt> ConnectedMachine MLP;</tt><tt></tt>
<p><tt> <font color="#CC6600">// Create the layers of the MLP</font></tt>
<br><tt> Linear hidden_linear(n_inputs, n_hu);</tt>
<br><tt> Tanh hidden_nlinear(n_hu);</tt>
<br><tt> Linear output_linear(n_hu, n_targets);</tt>
<br><tt> Tanh output_nlinear(n_targets);</tt><tt></tt>
<p><tt> <font color="#CC6600">// Initialize the layers</font></tt>
<br><tt> hidden_linear.init();</tt>
<br><tt> hidden_nlinear.init();</tt>
<br><tt> output_linear.init();</tt>
<br><tt> output_nlinear.init();</tt><tt></tt>
<p><tt> <font color="#CC6600">// Add the layers (fully connected layers)
to the MLP</font></tt>
<br><tt> MLP.addFCL(&hidden_linear);</tt>
<br><tt> MLP.addFCL(&hidden_nlinear);</tt>
<br><tt> MLP.addFCL(&output_linear);</tt><tt></tt>
<p><tt> <font color="#CC6600">// If regression, don't add the tanh
output layer</font></tt>
<br><tt> if(!regression)</tt>
<br><tt> MLP.addFCL(&output_nlinear);</tt><tt></tt>
<p><tt> <font color="#CC6600">// Initialize the MLP</font></tt>
<br><tt> MLP.init();</tt>
<br><tt></tt> <tt></tt>
<p><tt> <font color="#FF0000">//=================== DataSets &
Measurers... ===================</font></tt><tt><font color="#FF0000"></font></tt>
<p><tt> <font color="#CC6600">// Create the training dataset (normalize
inputs)</font></tt>
<br><tt> FileDataSet data(file, n_inputs, n_targets, false, max_load);</tt>
<br><tt> data.setBOption("normalize inputs", true);</tt>
<br><tt> data.init();</tt><tt></tt>
<p><tt> <font color="#CC6600">// The list of measurers...</font></tt>
<br><tt> List *measurers = NULL;</tt><tt></tt>
<p><tt> <font color="#CC6600">// The class format</font></tt>
<br><tt> ClassFormat *class_format = NULL;</tt>
<br><tt> if(!regression)</tt>
<br><tt> {</tt>
<br><tt> if(n_targets == 1)</tt>
<br><tt> class_format = new TwoClassFormat(&data);</tt>
<br><tt> else</tt>
<br><tt> class_format = new OneHotClassFormat(&data);</tt>
<br><tt> }</tt><tt></tt>
<p><tt> <font color="#CC6600">// The validation set...</font></tt>
<br><tt> FileDataSet *valid_data = NULL;</tt>
<br><tt> MseMeasurer *valid_mse_meas = NULL;</tt>
<br><tt> ClassMeasurer *valid_class_meas = NULL;</tt><tt></tt>
<p><tt> <font color="#CC6600">// Create a validation set, if any</font></tt>
<br><tt> if(strcmp(valid_file, ""))</tt>
<br><tt> {</tt>
<br><tt> <font color="#CC6600">// Load the validation
set and normalize it with the</font></tt>
<br><tt><font color="#CC6600"> // statistics computed on the training
dataset</font></tt>
<br><tt> valid_data = new FileDataSet(valid_file, n_inputs,
n_targets);</tt>
<br><tt> valid_data->init();</tt>
<br><tt> valid_data->normalizeUsingDataSet(&data);</tt><tt></tt>
<p><tt> <font color="#CC6600">// Create an MSE measurer
and, if not in regression mode,</font></tt>
<br><tt><font color="#CC6600"> // a class error measurer
on the validation dataset</font></tt>
<br><tt> valid_mse_meas = new MseMeasurer(MLP.outputs,
valid_data, "the_valid_mse");</tt>
<br><tt> valid_mse_meas->init();</tt>
<br><tt> addToList(&measurers, 1, valid_mse_meas);</tt><tt></tt>
<p><tt> if(!regression)</tt>
<br><tt> {</tt>
<br><tt> valid_class_meas = new ClassMeasurer(MLP.outputs,
valid_data, class_format, "the_valid_class_err");</tt>
<br><tt> valid_class_meas->init();</tt>
<br><tt> addToList(&measurers, 1, valid_class_meas);</tt>
<br><tt> }</tt>
<br><tt> }</tt><tt></tt>
<p><tt> <font color="#CC6600">// Measurers on the training dataset</font></tt>
<br><tt> MseMeasurer *mse_meas = new MseMeasurer(MLP.outputs, &data,
"the_mse");</tt>
<br><tt> mse_meas->init();</tt>
<br><tt> addToList(&measurers, 1, mse_meas);</tt><tt></tt>
<p><tt> ClassMeasurer *class_meas = NULL;</tt>
<br><tt> if(!regression)</tt>
<br><tt> {</tt>
<br><tt> class_meas = new ClassMeasurer(MLP.outputs,
&data, class_format, "the_class_err");</tt>
<br><tt> class_meas->init();</tt>
<br><tt> addToList(&measurers, 1, class_meas);</tt>
<br><tt> }</tt><tt></tt>
<p><tt> <font color="#FF0000">//=================== The Trainer ===============================</font></tt>
<br><tt> </tt>
<br><tt> <font color="#CC6600">// The criterion for the GMTrainer
(MSE criterion)</font></tt>
<br><tt> MseCriterion mse(n_targets);</tt>
<br><tt> mse.init();</tt><tt></tt>
<p><tt> <font color="#CC6600">// The optimizer for the GMTrainer</font></tt>
<br><tt> StochasticGradient opt;</tt>
<br><tt> opt.setIOption("max iter", max_iter);</tt>
<br><tt> opt.setROption("end accuracy", accuracy);</tt>
<br><tt> opt.setROption("learning rate", learning_rate);</tt>
<br><tt> opt.setROption("learning rate decay", decay);</tt><tt></tt>
<p><tt> <font color="#CC6600">// The Gradient Machine Trainer</font></tt>
<br><tt> GMTrainer trainer(&MLP, &data, &mse, &opt);</tt><tt></tt>
<p><tt> <font color="#FF0000">//=================== Let's go... ===============================</font></tt><tt></tt>
<p><tt> <font color="#CC6600">// Print the number of parameters of
the MLP (just for fun)</font></tt>
<br><tt> message("Number of parameters: %d", MLP.n_params);</tt><tt></tt>
<p><tt> <font color="#CC6600">// If the user provides a previously
trained model,</font></tt>
<br><tt><font color="#CC6600"> // test it...</font></tt>
<br><tt> if( strcmp(test_model_file, "") )</tt>
<br><tt> {</tt>
<br><tt> trainer.load(test_model_file);</tt>
<br><tt> trainer.test(measurers);</tt>
<br><tt> }</tt><tt></tt>
<p><tt> <font color="#CC6600">// ...else...</font></tt>
<br><tt> else</tt>
<br><tt> {</tt>
<br><tt> <font color="#CC6600">// If the user provided
a number of folds,</font></tt>
<br><tt><font color="#CC6600"> // do a K-fold cross-validation</font></tt>
<br><tt> if(k_fold > 0)</tt>
<br><tt> trainer.crossValidate(k_fold, NULL,
measurers);</tt><tt></tt>
<p><tt> <font color="#CC6600">// Else, train the model</font></tt>
<br><tt> else</tt>
<br><tt> trainer.train(measurers);</tt><tt></tt>
<p><tt> <font color="#CC6600">// Save the model if the
user provided a filename</font></tt>
<br><tt> if( strcmp(model_file, "") )</tt>
<br><tt> trainer.save(model_file);</tt>
<br><tt> }</tt><tt></tt>
<p><tt> <font color="#FF0000">//=================== Quit... ===================================</font></tt>
<br><tt> if(strcmp(valid_file, ""))</tt>
<br><tt> {</tt>
<br><tt> delete valid_data;</tt>
<br><tt> delete valid_mse_meas;</tt>
<br><tt> if(!regression)</tt>
<br><tt> delete valid_class_meas;</tt>
<br><tt> }</tt><tt></tt>
<p><tt> delete mse_meas;</tt>
<br><tt> if(!regression)</tt>
<br><tt> {</tt>
<br><tt> delete class_meas;</tt>
<br><tt> delete class_format;</tt>
<br><tt> }</tt><tt></tt>
<p><tt> freeList(&measurers);</tt><tt></tt>
<p><tt> return(0);</tt>
<br><tt>}</tt>
<br>
</body>
</html>