1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
|
// Banner text handed to the command-line parser (via cmd.info) so that it
// is printed ahead of the argument/option descriptions.
const char *help =
  "GMM (c) Samy Bengio & Co 2001\n"
  "\n"
  "This program will maximize the likelihood of data given a Diagonal GMM \n";
#include "EMTrainer.h"
#include "DiagonalGMM.h"
#include "Kmeans.h"
#include "MatSeqDataSet.h"
#include "CmdLine.h"
#include "NllMeasurer.h"
using namespace Torch;
// Builds the string "<dir>/<file>" in a buffer sized to hold exactly that
// string, so arbitrarily long directory or file names cannot overflow it.
// The buffer is obtained from xalloc(); the caller releases it with free().
static char *joinPath(const char *dir, const char *file)
{
  // +2: one byte for the '/' separator, one for the terminating '\0'
  char *path = (char*)xalloc(sizeof(char)*(strlen(dir)+strlen(file)+2));
  sprintf(path,"%s/%s",dir,file);
  return path;
}

// Program entry point.
//
// Reads the command line, loads the training file, and then either
//   - trains a DiagonalGMM (initialised through a Kmeans trainer) when no
//     "-lm" model file is given, or
//   - initialises the DiagonalGMM from the "-lm" file and only tests it.
// Measurements are written to "<dir>/kmeans_val" and "<dir>/gmm_val"; if
// "-sm" is given the trainer state is saved to "<dir>/<sm>".
//
// Returns 0.
int main(int argc, char **argv)
{
  char *train_file;
  int max_load;
  int seed_value;
  real accuracy;
  real threshold;
  int max_iter_kmeans;
  int max_iter_gmm;
  char *dir_name;
  int n_gaussians;
  real prior;
  char *load_model;
  char *save_model;

  // Construct the command line
  CmdLine cmd;

  // Put the help line at the beginning
  cmd.info(help);

  // Ask for arguments
  cmd.addText("\nArguments:");
  cmd.addSCmdArg("file", &train_file, "the train file");

  // Propose some options
  cmd.addText("\nModel Options:");
  cmd.addICmdOption("-n_gaussians", &n_gaussians, 10, "number of Gaussians");
  cmd.addRCmdOption("-threshold", &threshold, 0.0001, "variance threshold");
  cmd.addRCmdOption("-prior", &prior, 0.001, "prior on the weights");

  cmd.addText("\nLearning Options:");
  cmd.addICmdOption("-iterk", &max_iter_kmeans, 25, "max number of iterations of Kmeans");
  cmd.addICmdOption("-iterg", &max_iter_gmm, 25, "max number of iterations of GMM");
  cmd.addRCmdOption("-e", &accuracy, 0.0001, "end accuracy");

  cmd.addText("\nMisc Options:");
  cmd.addICmdOption("-load", &max_load, -1, "max number of examples to load");
  cmd.addICmdOption("-seed", &seed_value, -1, "initial seed for random generator");
  cmd.addSCmdOption("-dir", &dir_name, ".", "directory to save measures");
  cmd.addSCmdOption("-lm", &load_model, "", "start from given model file");
  cmd.addSCmdOption("-sm", &save_model, "", "save results into given model file");

  // Read the command line
  cmd.read(argc, argv);

  // If the user didn't give any random seed,
  // generate a random random seed...
  if (seed_value == -1)
    seed();
  else
    manual_seed((long)seed_value);

  // load the data (each line is an example with 1 frame)
  MatSeqDataSet data(train_file, 0,-1,0,false, max_load);
  data.init();
  data.toOneFramePerExample();
  int n_observations = data.n_observations;

  // create the variance threshold vector for Kmeans and GMM
  // (the same threshold value is used for every observation dimension)
  real* thresh = (real*)xalloc(n_observations*sizeof(real));
  for (int i=0;i<n_observations;i++)
    thresh[i] = threshold;

  // create a Kmeans object to initialize the GMM
  Kmeans kmeans(n_observations,n_gaussians,thresh,prior,&data);
  kmeans.init();
  kmeans.reset();

  // create a trainer to train the Kmeans
  EMTrainer* kmeans_trainer = new EMTrainer(&kmeans,&data);
  kmeans_trainer->setROption("end accuracy", accuracy);
  kmeans_trainer->setIOption("max iter", max_iter_kmeans);

  // create a measurer to measure the iterative performance of Kmeans
  List* ptr_meas_kmeans[1];
  ptr_meas_kmeans[0] = NULL;
  // dir_name comes from the command line, so the path buffer is sized
  // from it instead of using a fixed-length array
  char *kmeans_name = joinPath(dir_name,"kmeans_val");
  NllMeasurer vec_meas_kmeans(kmeans.outputs,&data,kmeans_name);
  vec_meas_kmeans.init();
  addToList(&ptr_meas_kmeans[0],1,&vec_meas_kmeans);

  // create a GMM either from the kmeans parameters or from file
  DiagonalGMM* gmm;
  // only consulted when a model file was requested with -lm
  char *load_model_name = joinPath(dir_name,load_model);
  if (!strcmp(load_model,"")) {
    // no model file: let the GMM initialise itself through the Kmeans trainer
    gmm = new DiagonalGMM(n_observations,n_gaussians,thresh,prior);
    gmm->setOption("initial kmeans trainer",&kmeans_trainer);
    gmm->setOption("initial kmeans trainer measurers",&ptr_meas_kmeans);
  } else {
    // model file given: initialise the GMM from "<dir>/<load_model>"
    gmm = new DiagonalGMM(n_observations,n_gaussians,thresh,prior);
    gmm->setOption("initial file",&load_model_name);
  }
  gmm->init();
  gmm->reset();

  // create the EM trainer to train the GMM
  EMTrainer trainer(gmm,&data);
  trainer.setROption("end accuracy", accuracy);
  trainer.setIOption("max iter", max_iter_gmm);

  // create a measurer to measure the negative log likelihood of the GMM
  List *meas_gmm = NULL;
  char *gmm_name = joinPath(dir_name,"gmm_val");
  NllMeasurer vec_meas_gmm(gmm->outputs,&data,gmm_name);
  vec_meas_gmm.init();
  addToList(&meas_gmm,1,&vec_meas_gmm);

  // either train or test the GMM:
  // a loaded model is only tested, otherwise the GMM is trained
  if (strcmp(load_model,"")) {
    trainer.test(meas_gmm);
  } else {
    trainer.train(meas_gmm);
  }

  // eventually, save the parameters of the model
  if (strcmp(save_model,"")) {
    char *save_model_name = joinPath(dir_name,save_model);
    trainer.save(save_model_name);
    free(save_model_name);
  }

  // free all the allocated memory
  free(gmm_name);
  free(kmeans_name);
  free(load_model_name);
  free(thresh);
  freeList(&ptr_meas_kmeans[0]);
  freeList(&meas_gmm);
  delete gmm;
  delete kmeans_trainer;
  return(0);
}
|