const char *help = "\
HMM (c) Trebolloc & Co 2001\n\
\n\
This program will train an HMM.\n";
#include "EMTrainer.h"
#include "DiagonalGMM.h"
#include "Kmeans.h"
#include "HMM.h"
#include "MatSeqDataSet.h"
#include "CmdLine.h"
#include "NllMeasurer.h"
using namespace Torch;
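// Overall flow of this example: parse the command line, load the training
// and test sequences, initialize one GMM per emitting state with Kmeans,
// build an HMM on top of these GMMs, train it with EM (or load an existing
// model), measure the negative log-likelihood on both data sets, and
// finally print the Viterbi state path of every test sequence.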
int main(int argc, char **argv)
{
char *train_file;
char *test_file;
int max_load_train;
int max_load_test;
int seed_value;
real accuracy;
real threshold;
int max_iter_kmeans;
int max_iter_hmm;
char *dir_name;
char *load_model;
char *save_model;
int n_gaussians;
int n_states;
real prior;
bool left_right;
// Construct the command line
CmdLine cmd;
cmd.info(help);
// Propose some options
cmd.addText("\nArguments:");
cmd.addSCmdArg("train_file", &train_file, "the train files, in double-quote");
cmd.addSCmdArg("test_file", &test_file, "the test files, in double-quote");
cmd.addText("\nModel Options:");
cmd.addICmdOption("-n_gaussians", &n_gaussians, 10, "number of Gaussians");
cmd.addICmdOption("-n_states", &n_states, 5, "number of states");
cmd.addRCmdOption("-threshold", &threshold, 0.001, "stdev threshold");
cmd.addRCmdOption("-prior", &prior, 0.001, "prior on the weights and transitions");
cmd.addBCmdOption("-left_right", &left_right, false, "left-right connectivity (otherwise: full-connect)");
cmd.addText("\nLearning Options:");
cmd.addICmdOption("-iterk", &max_iter_kmeans, 25, "max number of iterations of Kmeans");
cmd.addICmdOption("-iterg", &max_iter_hmm, 25, "max number of iterations of HMM");
cmd.addRCmdOption("-e", &accuracy, 0.0001, "end accuracy");
cmd.addText("\nMisc Options:");
cmd.addICmdOption("-load_train", &max_load_train, -1, "max number of train examples to load");
cmd.addICmdOption("-load_test", &max_load_test, -1, "max number of test examples to load");
cmd.addICmdOption("-seed", &seed_value, -1, "initial seed for random generator");
cmd.addSCmdOption("-dir", &dir_name, ".", "directory to save measures");
cmd.addSCmdOption("-lm", &load_model, "", "start from given model file");
cmd.addSCmdOption("-sm", &save_model, "", "save results into given model file");
// Read the command line
cmd.read(argc, argv);
// If the user didn't give any random seed,
// generate one randomly...
if (seed_value == -1)
seed();
else
manual_seed((long)seed_value);
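// (seed() presumably derives a seed from the system clock, whereas
// manual_seed() makes the run reproducible for a given -seed value.)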
// load the train data (each file is a sequence)
char* train_files[1000];
int n;
train_files[0] = strtok(train_file," ");
for (n=1;(train_files[n] = strtok(NULL," "));n++);
if (n!=1) {
n = max_load_train > 0 && max_load_train < n ? max_load_train : n;
max_load_train = -1;
}
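// Note: train_file is a single argument holding a space-separated list of
// files, hence the strtok() loop above. When several files are given, each
// file is one sequence, so max_load_train is used here to limit the number
// of files and is then reset to -1 (each kept file is loaded entirely).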
MatSeqDataSet data(train_files, n, 0,-1,0,false, max_load_train);
data.init();
int n_observations = data.n_observations;
// load the test data (each file is a sequence)
char* test_files[1000];
test_files[0] = strtok(test_file," ");
for (n=1;(test_files[n] = strtok(NULL," "));n++);
if (n!=1) {
n = max_load_test > 0 && max_load_test < n ? max_load_test : n;
max_load_test = -1;
}
MatSeqDataSet test_data(test_files, n, 0,-1,0,false,max_load_test);
test_data.init();
// for each emitting state of the HMM, we need a GMM, a Kmeans to initialize it
// and a trainer to train the Kmeans.
Kmeans** kmeans = new Kmeans *[n_states];
DiagonalGMM** gmm = new DiagonalGMM *[n_states];
EMTrainer** kmeans_trainer = new EMTrainer *[n_states];
// each Gaussian of the GMMs inside the HMM gets a variance threshold
real* thresh = (real*)xalloc(n_observations*sizeof(real));
for (int i=0;i<n_observations;i++)
thresh[i] = threshold;
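// The same threshold is used for every observation dimension; it acts as a
// floor that prevents the Gaussian variances from collapsing to zero
// during EM training.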
// this initializes the GMMs of the HMM. There are two cases:
// either we select the "left-right" topology, in which case each state is
// initialized from a linear segmentation of the sequences,
// or we select the "full-connect" topology, in which case vectors are
// selected at random for each state (for initialization purposes only).
data.totNFrames();
n = data.n_examples;
// for each state, we will select its frames for the Kmeans initialization
for (int i=1;i<n_states-1;i++) {
if (left_right) {
// linear segmentation
data.linearSegmentation(i,n_states);
} else {
// select a bootstrap
data.selectBootstrap();
}
// for each state, we create a Kmeans
kmeans[i] = new Kmeans(n_observations,n_gaussians,thresh,prior,&data);
kmeans[i]->init();
// ... and a trainer for this kmeans
kmeans_trainer[i] = new EMTrainer(kmeans[i],&data);
kmeans_trainer[i]->setROption("end accuracy", accuracy);
kmeans_trainer[i]->setIOption("max iter", max_iter_kmeans);
// ... which we train
if (!strcmp(load_model,"")) {
kmeans[i]->reset();
kmeans_trainer[i]->train(NULL);
}
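// (when a model is given with -lm, the Kmeans is left untrained: the
// parameters computed here would be overwritten when the model is loaded
// further down anyway)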
// restore the original data (unselect the frames chosen for this state)
data.unsetAllSelectedFrames();
// ... then we create a GMM
gmm[i] = new DiagonalGMM(n_observations,n_gaussians,thresh,prior);
// ... whose parameters we initialize with the Kmeans parameters
gmm[i]->setOption("initial params",&kmeans[i]->params);
gmm[i]->init();
}
// note that HMMs have two non-emitting states: the initial and final states
gmm[0] = NULL;
gmm[n_states-1] = NULL;
// we create the transition matrix with initial transition probabilities
real** transitions = (real**)xalloc(n_states*sizeof(real*));
for (int i=0;i<n_states;i++) {
transitions[i] = (real*)xalloc(n_states*sizeof(real));
}
for (int i=0;i<n_states;i++) {
for (int j=0;j<n_states;j++)
transitions[i][j] = 0;
}
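// Convention suggested by the initializations below: transitions[j][i] is
// the probability of going from state i to state j, so that each column of
// the matrix sums to one.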
// ... the left_right transition matrix
if (left_right) {
transitions[1][0] = 1;
for (int i=1;i<n_states-1;i++) {
transitions[i][i] = 0.5;
transitions[i+1][i] = 0.5;
}
} else {
// ... the full_connect transition matrix
for (int i=1;i<n_states-1;i++) {
transitions[i][0] = 1./(n_states-2);
for (int j=1;j<n_states;j++) {
transitions[j][i] = 1./(n_states-1);
}
}
}
// we create the HMM
HMM hmm(n_states,(Distribution**)gmm,prior,&data,transitions);
hmm.init();
hmm.reset();
// ... and its associated trainer
EMTrainer trainer(&hmm,&data);
trainer.setROption("end accuracy", accuracy);
trainer.setIOption("max iter", max_iter_hmm);
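// (The two options above are the EM stopping criteria: training runs for at
// most max iter iterations, and stops earlier once the improvement of the
// log-likelihood between two iterations drops below the end accuracy.)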
// ... as well as its associated measurers
List *meas_hmm = NULL;
// ... one for the training data
char hmm_name[100];
sprintf(hmm_name,"%s/hmm_val",dir_name);
NllMeasurer vec_meas_hmm(hmm.outputs,&data,hmm_name);
vec_meas_hmm.init();
addToList(&meas_hmm,1,&vec_meas_hmm);
// ... and one for the test data
char hmm_test_name[100];
sprintf(hmm_test_name,"%s/hmm_test_val",dir_name);
NllMeasurer vec_meas_hmm_test(hmm.outputs,&test_data,hmm_test_name);
vec_meas_hmm_test.init();
addToList(&meas_hmm,1,&vec_meas_hmm_test);
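// Each NllMeasurer writes the negative log-likelihood of the model, on the
// training and the test data respectively, to the files hmm_val and
// hmm_test_val inside dir_name.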
// either load the model and test it...
if (strcmp(load_model,"")) {
char load_model_name[100];
sprintf(load_model_name,"%s/%s",dir_name,load_model);
trainer.load(load_model_name);
trainer.test(meas_hmm);
} else
// or train the model
trainer.train(meas_hmm);
// perform a Viterbi decoding on every test sequence
// and print the best state path
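// (logProbabilities() is called before logViterbi() for each example,
// presumably to fill in the per-state emission log-probabilities the
// Viterbi pass relies on; the decoded path is then read from
// hmm.viterbi_sequence.)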
for (int i=0;i<test_data.n_examples;i++) {
test_data.setExample(i);
SeqExample* ex = (SeqExample*)test_data.inputs->ptr;
printf("viterbi %d (%d): ",test_data.current_example,ex->n_frames);
hmm.logProbabilities(test_data.inputs);
hmm.logViterbi(ex);
for (int j=1;j<=ex->n_frames;j++)
printf("%d ",hmm.viterbi_sequence[j]);
printf("\n");
}
// save the model if requested
if (strcmp(save_model,"")) {
char save_model_name[100];
sprintf(save_model_name,"%s/%s",dir_name,save_model);
trainer.save(save_model_name);
}
// free all the allocated memory
for (int i=1;i<n_states-1;i++) {
delete kmeans[i];
delete gmm[i];
delete kmeans_trainer[i];
}
delete[] kmeans;
delete[] gmm;
delete[] kmeans_trainer;
for (int i=0;i<n_states;i++)
free(transitions[i]);
free(transitions);
free(thresh);
freeList(&meas_hmm);
return(0);
}