File: node-gini.hpp

package info (click to toggle)
cgal 6.1.1-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 144,952 kB
  • sloc: cpp: 811,597; ansic: 208,576; sh: 493; python: 411; makefile: 286; javascript: 174
file content (114 lines) | stat: -rw-r--r-- 4,334 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
// Copyright (c) 2014 Stefan Walk
//
// This file is part of CGAL (www.cgal.org).
//
// $URL: https://github.com/CGAL/cgal/blob/v6.1.1/Classification/include/CGAL/Classification/ETHZ/internal/random-forest/node-gini.hpp $
// $Id: include/CGAL/Classification/ETHZ/internal/random-forest/node-gini.hpp 08b27d3db14 $
// SPDX-License-Identifier: LicenseRef-RFL
// License notice in Installation/LICENSE.RFL
//
// Author(s)     : Stefan Walk

// Modifications from original library:
//  * changed inclusion protection tag
//  * moved to namespace CGAL::internal::

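//  * the sort comparator compares only the first element of each pair (the
//    feature value); the second element (the class label) does not affect
//    the threshold search, so comparing it is unnecessary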



#ifndef CGAL_INTERNAL_LIBLEARNING_RANDOMFOREST_NODE_GINI_H
#define CGAL_INTERNAL_LIBLEARNING_RANDOMFOREST_NODE_GINI_H
#include "node.hpp"
#include "common-libraries.hpp"

namespace CGAL { namespace internal {

namespace liblearning {
namespace RandomForest {

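/*
NodeGini derives from Node and passes itself as Node's first template
argument (curiously recurring template pattern, CRTP):
template <typename T>
class X : public Y< X<T> > {};
-> https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern#Static_polymorphism
*/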

/// Random-forest tree node that picks split thresholds by minimizing the
/// size-weighted Gini impurity of the two resulting children.
///
/// The class derives from Node<NodeGini<Splitter>, ForestParams, Splitter>,
/// i.e. it passes itself to its base (CRTP); the base presumably invokes
/// determine_best_threshold() statically -- see node.hpp.
template <typename Splitter>
class NodeGini : public Node< NodeGini<Splitter>, ForestParams, Splitter > {
public:
    typedef typename Node< NodeGini<Splitter>, ForestParams, Splitter>::ParamType ParamType;
    typedef typename Splitter::FeatureType FeatureType;
    typedef typename Splitter::FeatureClassData FeatureClassData;
    using Node< NodeGini<Splitter>, ForestParams, Splitter>::params;
    NodeGini() {}
    NodeGini(int depth, ParamType const* params) :
        Node< NodeGini<Splitter>, ForestParams, Splitter>(depth, params)
    {
    }

    /// Returns sum_k frequencies[k]^2, the count-dependent part of the
    /// (unnormalized) Gini impurity of a class histogram.
    uint64_t gini_square_term(std::vector<uint64_t> const& frequencies) const
    {
        return std::inner_product( frequencies.begin(), frequencies.end(), frequencies.begin(), uint64_t(0));
    }

    /// Finds the threshold on one feature that minimizes the size-weighted
    /// Gini impurity of the left/right partition of @p data_points.
    ///
    /// For n_l points on the left and n_r on the right, with per-class counts
    /// c_l[k] and c_r[k], the loss is
    ///   n_l * (1 - sum_k (c_l[k]/n_l)^2) + n_r * (1 - sum_k (c_r[k]/n_r)^2)
    /// = n_l - sum_k c_l[k]^2 / n_l + n_r - sum_k c_r[k]^2 / n_r
    /// (lower is better).
    ///
    /// @param data_points (feature value, class index) pairs. They are sorted
    ///        in place by feature value as a side effect. Class indices are
    ///        used unchecked to index vectors of size params->n_classes.
    /// @param classes_l, classes_r scratch histograms, re-assigned to
    ///        params->n_classes zeros on entry. On return they reflect the
    ///        last candidate position scanned, not the best split.
    /// @param gen random generator used to place the threshold between the
    ///        two neighbouring distinct feature values of the best split.
    /// @return (threshold, loss). If data_points holds fewer than two distinct
    ///        feature values, no split is possible and (0, +infinity) is
    ///        returned.
    std::pair<FeatureType, float> determine_best_threshold(FeatureClassData& data_points,
                                                     std::vector<uint64_t>& classes_l,
                                                     std::vector<uint64_t>& classes_r,
                                                     RandomGen&        gen)
    {
        // Compare the container's own element type so sorting never converts
        // (and possibly narrows) elements to a hard-coded pair type.
        typedef typename FeatureClassData::value_type DataPoint;

        double best_loss = std::numeric_limits<double>::infinity();
        FeatureType best_thresh = FeatureType(0);

        UnitDist fraction_dist;
        classes_l.assign(params->n_classes, 0);
        classes_r.assign(params->n_classes, 0);
        // Initially every point lies in the right partition.
        double n_l = 0;
        double n_r = 0;
        for (size_t i_sample = 0; i_sample < data_points.size(); ++i_sample) {
            classes_r[data_points[i_sample].second]++;
            n_r += 1;
        }
        // Sort by feature value only, so every candidate threshold corresponds
        // to a position in the array; the class label is irrelevant here.
        std::sort(data_points.begin(), data_points.end(),
                  [](const DataPoint& a, const DataPoint& b) -> bool
                  {
                    return a.first < b.first;
                  });

        // Sweep left to right, moving one point at a time from the right
        // partition into the left one.
        for (size_t i_point = 1; i_point < data_points.size(); ++i_point) {
            int cls = data_points[i_point-1].second;
            classes_l[cls]++; // sample with class cls gets moved to left ...
            classes_r[cls]--; // from right
            n_l += 1;
            n_r -= 1;
            // A threshold cannot separate equal feature values.
            if (data_points[i_point-1].first == data_points[i_point].first)
                continue;
            // Size-weighted Gini impurity of both children (formula above).
            double gini = n_l - gini_square_term(classes_l) * 1.0 / n_l + n_r - gini_square_term(classes_r) * 1.0 / n_r;
            if (gini < best_loss) {
                best_loss = gini;
                // Place the threshold at a random position between the two
                // neighbouring distinct values rather than at the midpoint.
                double fraction = fraction_dist(gen);
                best_thresh = static_cast<FeatureType>(fraction * data_points[i_point-1].first
                                                       + (1-fraction) * data_points[i_point].first);
            }
        }
        return std::make_pair(best_thresh, float(best_loss));
    }

#if defined(CGAL_LINKED_WITH_BOOST_IOSTREAMS) && defined(CGAL_LINKED_WITH_BOOST_SERIALIZATION)
    template <typename Archive>
    void serialize(Archive& ar, unsigned /* version */)
    {
        ar & boost::serialization::make_nvp("base",  boost::serialization::base_object< Node< NodeGini<Splitter>, ForestParams, Splitter > >(*this));
    }
#endif
};

}
}

}} // namespace CGAL::internal::

#endif