File: trainingcontext.h

package info (click to toggle)
itksnap 3.4.0-2
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 10,196 kB
  • ctags: 9,196
  • sloc: cpp: 62,895; sh: 175; makefile: 13
file content (140 lines) | stat: -rwxr-xr-x 3,618 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
/**
 * Defines different training contexts derived from the abstract TrainingContext
 * class by instantiating it with different statistics and classifier types.
 */

#ifndef TRAININGCONTEXT_H
#define TRAININGCONTEXT_H

#include <cstddef>
#include <stdexcept>
#include <vector>

#include "classifier.h"
#include "statistics.h"

/**
 * Plain aggregate of hyper-parameters for training.
 *
 * There is no constructor and no member initializers, so the scalar members
 * are indeterminate after default construction; callers must set every field.
 * The per-field notes below are inferred from the member names only —
 * NOTE(review): confirm against the trainer that consumes this struct.
 */
struct TrainingParameters
{
  size_t treeNum;                          // presumably number of trees to train
  size_t treeDepth;                        // presumably maximum depth of each tree
  size_t candidateNodeClassifierNum;       // presumably random candidate classifiers tried per node
  size_t candidateClassifierThresholdNum;  // presumably thresholds tried per candidate classifier
  std::vector<double> weights;             // weights vector; ComputeIG-style APIs take one — TODO confirm it is per-class
  double subSamplePercent;                 // presumably share of samples drawn per tree — confirm units (fraction vs percent)
  double splitIG;                          // presumably minimum information gain required to split a node
  double leafEntropy;                      // presumably entropy at/below which a node becomes a leaf
  bool verbose;                            // presumably enables diagnostic output
};

/**
 * Abstract interface through which a trainer obtains random candidate split
 * classifiers and per-node statistics objects, and scores candidate splits.
 *
 * @tparam S  statistics type accumulated at a node
 * @tparam C  split-classifier type
 */
template<class S, class C>
class TrainingContext
{
public:
  // Virtual so that deleting a derived context (the concrete contexts own heap
  // memory) through a TrainingContext pointer runs the derived destructor
  // instead of being undefined behavior.
  virtual ~TrainingContext() {}

  // Randomly get a classifier.
  virtual C RandomClassifier(Random& random) = 0;

  // Get an object of statistics.
  virtual S Statistics() = 0;

  // Compute the information gain of splitting `parent` into
  // `leftChild` and `rightChild`, using `weights`.
  virtual double ComputeIG(S& parent, S& leftChild, S& rightChild,
                           const std::vector<double>& weights) = 0;
};

/**
 * Training context whose node statistics are Histogram<dataT, labelT> objects
 * and whose splits are scored by the weighted-entropy information gain.
 *
 * Owns one heap-allocated prototype classifier (built from the feature
 * dimension) from which random candidate classifiers are drawn.
 */
template<class C, class dataT, class labelT>
class ClassificationContext : public TrainingContext<Histogram<dataT, labelT>, C>
{
public:
  /**
   * @param featureDim  forwarded to the C constructor for the prototype classifier
   * @param classNum    forwarded to the Histogram constructor by Statistics()
   */
  ClassificationContext(int featureDim, int classNum)
    : classifier_(new C(featureDim)), classNum_(classNum)
  {
  }

  // Deep copy: classifier_ is owned, so a copy needs its own instance rather
  // than sharing (and later double-deleting) the same pointer.
  ClassificationContext(const ClassificationContext& other)
    : TrainingContext<Histogram<dataT, labelT>, C>(other),
      classifier_(new C(*other.classifier_)),
      classNum_(other.classNum_)
  {
  }

  ClassificationContext& operator=(const ClassificationContext& other)
  {
    if (this != &other)
      {
      // Allocate before releasing so *this is untouched if new / C's copy throws.
      C* copy = new C(*other.classifier_);
      delete classifier_;
      classifier_ = copy;
      classNum_ = other.classNum_;
      }
    return *this;
  }

  ~ClassificationContext()
  {
    delete classifier_;
  }

  // Delegates to the prototype classifier's RandomClassifier().
  C RandomClassifier(Random &random)
  {
    return classifier_->RandomClassifier(random);
  }

  // Returns a new Histogram constructed with classNum_.
  Histogram<dataT, labelT> Statistics()
  {
    return Histogram<dataT, labelT>(classNum_);
  }

  /**
   * Information gain of splitting `parent` into `leftChild` / `rightChild`:
   *
   *   IG = H(parent) - (nL * H(left) + nR * H(right)) / nP
   *
   * where H(x) is x.Entropy(weights) and nP/nL/nR are the sampleNum_ counts.
   *
   * @return 0.0 when either child has no samples (degenerate split)
   * @throws std::runtime_error if nL + nR != nP
   */
  double ComputeIG(Histogram<dataT, labelT>& parent,
                   Histogram<dataT, labelT>& leftChild,
                   Histogram<dataT, labelT>& rightChild,
                   const std::vector<double>& weights)
  {
    const std::size_t pSampleNum = parent.sampleNum_;
    const std::size_t lSampleNum = leftChild.sampleNum_;
    const std::size_t rSampleNum = rightChild.sampleNum_;

    if (lSampleNum == 0 || rSampleNum == 0)
      {
      return 0.0;
      }
    if (lSampleNum + rSampleNum != pSampleNum)
      {
      throw std::runtime_error("ComputeIG sampleNum error!");
      }

    // Children's entropy weighted by their share of the parent's samples.
    const double childEntropy =
      (lSampleNum * leftChild.Entropy(weights) +
       rSampleNum * rightChild.Entropy(weights)) / pSampleNum;
    return parent.Entropy(weights) - childEntropy;
  }

  C* classifier_;          // owned prototype classifier
  std::size_t classNum_;   // passed to Histogram in Statistics()
};

/**
 * Training context whose node statistics are GaussianStat<dataT, labelT>
 * objects and whose splits are scored by an entropy-based information gain.
 *
 * Owns one heap-allocated prototype classifier (built from the feature
 * dimension) from which random candidate classifiers are drawn.
 */
template<class C, class dataT, class labelT>
class DensityEstimationContext : public TrainingContext<GaussianStat<dataT, labelT>, C>
{
public:
  /**
   * @param featureDim  forwarded to both the C constructor (prototype
   *                    classifier) and the GaussianStat constructor
   */
  DensityEstimationContext(int featureDim)
    : classifier_(new C(featureDim)), featureDim_(featureDim)
  {
  }

  // Deep copy: classifier_ is owned, so a copy needs its own instance rather
  // than sharing (and later double-deleting) the same pointer.
  DensityEstimationContext(const DensityEstimationContext& other)
    : TrainingContext<GaussianStat<dataT, labelT>, C>(other),
      classifier_(new C(*other.classifier_)),
      featureDim_(other.featureDim_)
  {
  }

  DensityEstimationContext& operator=(const DensityEstimationContext& other)
  {
    if (this != &other)
      {
      // Allocate before releasing so *this is untouched if new / C's copy throws.
      C* copy = new C(*other.classifier_);
      delete classifier_;
      classifier_ = copy;
      featureDim_ = other.featureDim_;
      }
    return *this;
  }

  ~DensityEstimationContext()
  {
    delete classifier_;
  }

  // Delegates to the prototype classifier's RandomClassifier().
  C RandomClassifier(Random &random)
  {
    return classifier_->RandomClassifier(random);
  }

  // Returns a new GaussianStat constructed with featureDim_.
  GaussianStat<dataT, labelT> Statistics()
  {
    return GaussianStat<dataT, labelT>(featureDim_);
  }

  /**
   * Information gain of splitting `parent` into `leftChild` / `rightChild`:
   *
   *   IG = H(parent) - (nL * H(left) + nR * H(right)) / nP
   *
   * The parameters are GaussianStat (not Histogram) so that this actually
   * overrides the base class's pure virtual ComputeIG; with Histogram
   * parameters the class remained abstract and could not be constructed.
   * NOTE(review): assumes GaussianStat exposes sampleNum_ and
   * Entropy(const std::vector<double>&) like Histogram — confirm in statistics.h.
   *
   * @return 0.0 when either child has no samples (degenerate split)
   * @throws std::runtime_error if nL + nR != nP
   */
  double ComputeIG(GaussianStat<dataT, labelT>& parent,
                   GaussianStat<dataT, labelT>& leftChild,
                   GaussianStat<dataT, labelT>& rightChild,
                   const std::vector<double>& weights)
  {
    const std::size_t pSampleNum = parent.sampleNum_;
    const std::size_t lSampleNum = leftChild.sampleNum_;
    const std::size_t rSampleNum = rightChild.sampleNum_;

    if (lSampleNum == 0 || rSampleNum == 0)
      {
      return 0.0;
      }
    if (lSampleNum + rSampleNum != pSampleNum)
      {
      throw std::runtime_error("ComputeIG sampleNum error!");
      }

    // Children's entropy weighted by their share of the parent's samples.
    const double childEntropy =
      (lSampleNum * leftChild.Entropy(weights) +
       rSampleNum * rightChild.Entropy(weights)) / pSampleNum;
    return parent.Entropy(weights) - childEntropy;
  }

  C* classifier_;            // owned prototype classifier
  std::size_t featureDim_;   // passed to GaussianStat in Statistics()
};

#endif // TRAININGCONTEXT_H