File: tree_engine.cpp

package info (click to toggle)
opencv 2.1.0-3%2Bsqueeze1
  • links: PTS, VCS
  • area: main
  • in suites: squeeze
  • size: 68,800 kB
  • ctags: 52,010
  • sloc: cpp: 554,793; xml: 475,942; ansic: 153,396; python: 18,622; sh: 428; makefile: 111
file content (79 lines) | stat: -rw-r--r-- 2,463 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include "ml.h"
#include <stdio.h>
/*
The sample demonstrates how to use different decision trees.
*/
/*
Prints the train and test errors and, when var_imp is non-null, a table of
its values (one "index  value" line per column).

train_err, test_err - error values, printed as-is with %f.
var_imp             - optional matrix of per-variable values; may be null,
                      in which case only the two errors are printed.
                      Only the first var_imp->cols elements are read, so a
                      single-row matrix is assumed.
*/
void print_result(float train_err, float test_err, const CvMat* var_imp)
{
    printf( "train error    %f\n", train_err );
    printf( "test error    %f\n\n", test_err );

    if (var_imp)
    {
        // Elements are read as float for CV_32FC1 matrices and as double
        // for every other matrix type.
        const bool is_flt = CV_MAT_TYPE( var_imp->type ) == CV_32FC1;

        printf( "variable importance\n" );
        for( int i = 0; i < var_imp->cols; i++)
        {
            printf( "%d     %f\n", i, is_flt ? var_imp->data.fl[i] : var_imp->data.db[i] );
        }
    }
    printf("\n");
}

/*
Sample entry point: loads a CSV data set, trains several tree-based models
on it and prints the train/test error (and variable importance where the
model exposes it) for each one.

Returns 0 when the data file was read and all models were processed,
1 when the data file could not be read.
*/
int main()
{
    // Handed to CvTrainTestSplit below; presumably the number of samples
    // that go into the training part of the split.
    const int train_sample_count = 300;

// Uncomment to run the sample on the agaricus-lepiota data set instead of
// the waveform one (this also enables the boosting section).
//#define LEPIOTA
#ifdef LEPIOTA
    const char* filename = "../../../OpenCV/samples/c/agaricus-lepiota.data";
#else
    const char* filename = "../../../OpenCV/samples/c/waveform.data";
#endif

    CvDTree dtree;
    CvBoost boost;
    CvRTrees rtrees;
    CvERTrees ertrees;

    CvMLData data;

    CvTrainTestSplit spl( train_sample_count );

    // read_csv() yields 0 on success; anything else means there is no data
    // to train on, so report the problem and signal failure to the caller.
    if ( data.read_csv( filename ) != 0 )
    {
        fprintf( stderr, "File %s can not be read\n", filename );
        return 1;
    }

#ifdef LEPIOTA
    data.set_response_idx( 0 );
#else
    data.set_response_idx( 21 );
    data.change_var_type( 21, CV_VAR_CATEGORICAL );
#endif

    data.set_train_test_split( &spl );

    printf("======DTREE=====\n");
    dtree.train( &data, CvDTreeParams( 10, 2, 0, false, 16, 0, false, false, 0 ));
    print_result( dtree.calc_error( &data, CV_TRAIN_ERROR), dtree.calc_error( &data, CV_TEST_ERROR ), dtree.get_var_importance() );

#ifdef LEPIOTA
    printf("======BOOST=====\n");
    boost.train( &data, CvBoostParams(CvBoost::DISCRETE, 100, 0.95, 2, false, 0));
    // The error type is passed explicitly, like every other calc_error call
    // in this sample; no variable importance is printed for boosting.
    print_result( boost.calc_error( &data, CV_TRAIN_ERROR ), boost.calc_error( &data, CV_TEST_ERROR ), 0 );
#endif

    printf("======RTREES=====\n");
    rtrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
    print_result( rtrees.calc_error( &data, CV_TRAIN_ERROR), rtrees.calc_error( &data, CV_TEST_ERROR ), rtrees.get_var_importance() );

    // Same parameter set as for RTREES above.
    printf("======ERTREES=====\n");
    ertrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
    print_result( ertrees.calc_error( &data, CV_TRAIN_ERROR), ertrees.calc_error( &data, CV_TEST_ERROR ), ertrees.get_var_importance() );

    return 0;
}