1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
|
// Suppress selected compiler warnings for every translation unit that
// includes this header; a "#pragma GCC diagnostic ignored" takes effect from
// this point to the end of the translation unit.
// NOTE(review): presumably needed because the test sources define free
// functions without prior declarations/prototypes — confirm against the .cpp files.
#ifdef __GNUC__
# pragma GCC diagnostic ignored "-Wmissing-declarations"
// On Clang or Apple toolchains, additionally silence -Wmissing-prototypes
// and the whole -Wextra group.
# if defined __clang__ || defined __APPLE__
# pragma GCC diagnostic ignored "-Wmissing-prototypes"
# pragma GCC diagnostic ignored "-Wextra"
# endif
#endif
#ifndef __OPENCV_TEST_PRECOMP_HPP__
#define __OPENCV_TEST_PRECOMP_HPP__
#include "opencv2/ts/ts.hpp"
#include "opencv2/ml/ml.hpp"
#include "opencv2/core/core_c.h"
#include <iostream>
#include <map>
// String identifiers of the ML models under test.
// NOTE(review): presumably these are the values passed as _modelName to the
// test class constructors below and stored in CV_MLBaseTest::modelName —
// confirm in the .cpp files.
// NOTE(review): CV_MLBaseTest declares model pointers only for nbayes,
// knearest, svm, ann, dtree, boost, rtrees and ertrees; there is no member
// for CV_EM — confirm whether "em" is still exercised by these tests.
#define CV_NBAYES "nbayes"
#define CV_KNEAREST "knearest"
#define CV_SVM "svm"
#define CV_EM "em"
#define CV_ANN "ann"
#define CV_DTREE "dtree"
#define CV_BOOST "boost"
#define CV_RTREES "rtrees"
#define CV_ERTREES "ertrees"
class CV_MLBaseTest : public cvtest::BaseTest
{
public:
CV_MLBaseTest( const char* _modelName );
virtual ~CV_MLBaseTest();
protected:
virtual int read_params( CvFileStorage* fs );
virtual void run( int startFrom );
virtual int prepare_test_case( int testCaseIdx );
virtual std::string& get_validation_filename();
virtual int run_test_case( int testCaseIdx ) = 0;
virtual int validate_test_results( int testCaseIdx ) = 0;
int train( int testCaseIdx );
float get_error( int testCaseIdx, int type, std::vector<float> *resp = 0 );
void save( const char* filename );
void load( const char* filename );
CvMLData data;
std::string modelName, validationFN;
std::vector<std::string> dataSetNames;
cv::FileStorage validationFS;
// MLL models
CvNormalBayesClassifier* nbayes;
CvKNearest* knearest;
CvSVM* svm;
CvANN_MLP* ann;
CvDTree* dtree;
CvBoost* boost;
CvRTrees* rtrees;
CvERTrees* ertrees;
std::map<int, int> cls_map;
int64 initSeed;
};
// Concrete CV_MLBaseTest that supplies the two per-test-case hooks left pure
// virtual by the base. NOTE(review): the "AML" prefix presumably means
// "accuracy ML test" — confirm against the .cpp implementation.
class CV_AMLTest : public CV_MLBaseTest
{
public:
    CV_AMLTest(const char* _modelName);

protected:
    // Hook overrides (base-class contract).
    virtual int run_test_case(int testCaseIdx);
    virtual int validate_test_results(int testCaseIdx);
};
// Concrete CV_MLBaseTest that supplies the two per-test-case hooks and keeps
// two sets of predicted responses plus two file names.
// NOTE(review): the "SL" prefix presumably means save/load (the base class
// exposes save()/load()) — confirm against the .cpp implementation.
class CV_SLMLTest : public CV_MLBaseTest
{
public:
    CV_SLMLTest(const char* _modelName);

protected:
    // Hook overrides (base-class contract).
    virtual int run_test_case(int testCaseIdx);
    virtual int validate_test_results(int testCaseIdx);

    // Predicted responses for the test data.
    std::vector<float> test_resps1;
    std::vector<float> test_resps2;

    // File names; presumably the paths handed to save()/load() — TODO confirm.
    std::string fname1;
    std::string fname2;
};
#endif
|