File: cross_validate_multiclass_trainer_abstract.h

package info (click to toggle)
mldemos 0.5.1-3
  • links: PTS, VCS
  • area: main
  • in suites: jessie, jessie-kfreebsd
  • size: 32,224 kB
  • ctags: 46,525
  • sloc: cpp: 306,887; ansic: 167,718; ml: 126; sh: 109; makefile: 2
file content (99 lines) | stat: -rw-r--r-- 4,127 bytes parent folder | download | duplicates (4)
// Copyright (C) 2010  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#undef DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_ABSTRACT_H__
#ifdef DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_ABSTRACT_H__

#include <vector>
#include "../matrix.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    template <
        typename dec_funct_type,
        typename sample_type,
        typename label_type
        >
    const matrix<double> test_multiclass_decision_function (
        const dec_funct_type& dec_funct,
        const std::vector<sample_type>& x_test,
        const std::vector<label_type>& y_test
    );
    /*!
        requires
            - is_learning_problem(x_test, y_test)
            - dec_funct_type == some kind of multiclass decision function object 
              (e.g. one_vs_one_decision_function)
        ensures
            - Tests dec_funct against the given samples in x_test and labels in y_test
              and returns a confusion matrix summarizing the results.
            - let L = dec_funct.get_labels().  Then the confusion matrix C returned 
              by this function has the following properties.
                - C.nr() == C.nc() == L.size()
                - C(r,c) == the number of times a sample with label L(r) was predicted
                  to have a label of L(c)
                  (i.e. rows index the true label and columns index the predicted
                  label, so a perfect classifier produces a diagonal matrix)
            - Any samples with a y_test value not in L are ignored.  That is, samples
              with labels the decision function hasn't ever seen before are ignored.
    !*/

// ----------------------------------------------------------------------------------------

    class cross_validation_error : public dlib::error 
    { 
        /*!
            This is the exception class used by the cross_validate_multiclass_trainer() 
            routine.  It is thrown when a cross validation can't be performed as
            requested (in particular, when a class has fewer samples than the
            number of requested folds).
        !*/
    };

// ----------------------------------------------------------------------------------------

    template <
        typename trainer_type,
        typename sample_type,
        typename label_type 
        >
    const matrix<double> cross_validate_multiclass_trainer (
        const trainer_type& trainer,
        const std::vector<sample_type>& x,
        const std::vector<label_type>& y,
        const long folds
    );
    /*!
        requires
            - is_learning_problem(x,y)
            - 1 < folds <= x.size()
            - trainer_type == some kind of multiclass classification trainer object (e.g. one_vs_one_trainer)
        ensures
            - performs k-fold cross validation by using the given trainer to solve the
              given multiclass classification problem for the given number of folds.
              Each fold is tested using the output of the trainer and the confusion
              matrix from all folds is summed and returned.
            - The total confusion matrix is computed by running test_multiclass_decision_function()
              on each fold and summing its output.
            - The number of folds used is given by the folds argument.
            - let L = select_all_distinct_labels(y).  Then the confusion matrix C returned 
              by this function has the following properties.
                - C.nr() == C.nc() == L.size()
                - C(r,c) == the number of times a sample with label L(r) was predicted
                  to have a label of L(c)

              Note that sum(C) might be slightly less than x.size().  This happens if the number of 
              samples in a class is not an even multiple of folds.  This is because each fold has the 
              same number of test samples in it and so if the number of samples in a class isn't a 
              multiple of folds then a few are not tested.  
        throws
            - cross_validation_error
              This exception is thrown if one of the classes has fewer samples than
              the number of requested folds.
    !*/

}

// ----------------------------------------------------------------------------------------

#endif // DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_ABSTRACT_H__