//
// forest.cpp
// Mothur
//
// Created by Kathryn Iverson on 10/26/12.
// Copyright (c) 2012 Schloss Lab. All rights reserved.
//
#include "forest.h"
/***********************************************************************/
// Returns the number of feature columns in `rows`: the width of the first row
// minus one (the last column is presumably the class label -- TODO confirm
// against callers). Yields 0 when there are no rows or the first row is empty,
// so the constructor never indexes an empty vector or produces a wrapped-around
// (negative) count.
static int forestFeatureCount(const std::vector< std::vector<int> >& rows) {
    if (rows.empty() || rows[0].empty()) { return 0; }
    return static_cast<int>(rows[0].size() - 1);
}

// Constructs a Forest over `dataSet` and records the tuning options.
//
//   dataSet                   one row per sample; copied into the object.
//   numDecisionTrees          stored as given.
//   treeSplitCriterion        stored as given (default "gainratio").
//   doPruning                 stored as given (default false).
//   pruneAggressiveness       stored as given (default 0.9).
//   discardHighErrorTrees     stored as given (default true).
//   highErrorTreeDiscardThreshold          stored as given (default 0.4).
//   optimumFeatureSubsetSelectionCriteria  stored as given (default "log2").
//   featureStandardDeviationThreshold      features whose standard deviation is
//                             <= this value are recorded as globally discarded
//                             (default 0.0).
//
// numSamples is the row count of dataSet; numFeatures is computed by
// forestFeatureCount(). After the members are set, the MothurOut singleton is
// fetched and the globally discarded feature indices are computed.
Forest::Forest(const std::vector < std::vector<int> > dataSet,
               const int numDecisionTrees,
               const string treeSplitCriterion = "gainratio",
               const bool doPruning = false,
               const float pruneAggressiveness = 0.9,
               const bool discardHighErrorTrees = true,
               const float highErrorTreeDiscardThreshold = 0.4,
               const string optimumFeatureSubsetSelectionCriteria = "log2",
               const float featureStandardDeviationThreshold = 0.0)
: dataSet(dataSet),
  numDecisionTrees(numDecisionTrees),
  numSamples((int)dataSet.size()),
  numFeatures(forestFeatureCount(dataSet)),
  // Sized from the constructor parameter rather than the numFeatures member:
  // members are initialised in class-declaration order, so reading the member
  // here would be unsafe if it happens to be declared after this list.
  globalVariableImportanceList(forestFeatureCount(dataSet), 0),
  treeSplitCriterion(treeSplitCriterion),
  doPruning(doPruning),
  pruneAggressiveness(pruneAggressiveness),
  discardHighErrorTrees(discardHighErrorTrees),
  highErrorTreeDiscardThreshold(highErrorTreeDiscardThreshold),
  optimumFeatureSubsetSelectionCriteria(optimumFeatureSubsetSelectionCriteria),
  featureStandardDeviationThreshold(featureStandardDeviationThreshold)
{
    // `m` must be set before getGlobalDiscardedFeatureIndices(), which uses it.
    m = MothurOut::getInstance();
    globalDiscardedFeatureIndices = getGlobalDiscardedFeatureIndices();
    // TODO: double check if the implementation of 'globalOutOfBagEstimates' is correct
}
/***********************************************************************/
// Finds the features whose values barely vary across samples.
//
// Transposes the first numFeatures columns of dataSet into one vector per
// feature, asks MothurOut for each vector's standard deviation, and appends the
// index of every feature whose deviation is <= featureStandardDeviationThreshold
// to the member globalDiscardedFeatureIndices. Returns a copy of that member.
//
// If m->control_pressed is set while working, returns the member as it stands
// at that point. Any std::exception is reported through m->errorOut and the
// process exits with status 1.
vector<int> Forest::getGlobalDiscardedFeatureIndices() {
    try {
        // Per-feature view of the data: valuesByFeature[f][s] is feature f of sample s.
        vector< vector<int> > valuesByFeature(numFeatures, vector<int>(numSamples, 0));

        for (int sample = 0; sample < numSamples; ++sample) {
            if (m->control_pressed) { return globalDiscardedFeatureIndices; }

            for (int feature = 0; feature < numFeatures; ++feature) {
                valuesByFeature[feature][sample] = dataSet[sample][feature];
            }
        }

        for (int feature = 0; feature < valuesByFeature.size(); ++feature) {
            if (m->control_pressed) { return globalDiscardedFeatureIndices; }

            double spread = m->getStandardDeviation(valuesByFeature[feature]);
            if (spread <= featureStandardDeviationThreshold) {
                globalDiscardedFeatureIndices.push_back(feature);
            }
        }

        if (m->debug) {
            m->mothurOut("number of global discarded features: " + toString(globalDiscardedFeatureIndices.size())+ "\n");
            m->mothurOut("total features: " + toString(valuesByFeature.size())+ "\n");
        }

        return globalDiscardedFeatureIndices;
    }
    catch(exception& e) {
        m->errorOut(e, "Forest", "getGlobalDiscardedFeatureIndices");
        exit(1);
    }
}
/***********************************************************************/