File: _partition_nodes.pyx

package info (click to toggle)
scikit-learn 1.2.1+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 23,280 kB
  • sloc: python: 184,491; cpp: 5,783; ansic: 854; makefile: 307; sh: 45; javascript: 1
file content (120 lines) | stat: -rw-r--r-- 4,082 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# BinaryTrees rely on partial sorts to partition their nodes during their
# initialisation.
#
# The C++ standard library exposes nth_element, an efficient partial sort for
# this situation, which has linear time complexity on average and the best
# performance in practice.
#
# To use std::nth_element (declared in the <algorithm> header), a few fixtures
# are defined using Cython:
# - partition_node_indices, a Cython function used in BinaryTrees, that calls
# - partition_node_indices_inner, a C++ function that wraps nth_element and uses
# - an IndexComparator to state how to compare BinaryTrees' indices
#
# IndexComparator breaks ties between equal values by comparing the indices
# themselves, so that partial sorts are stable with respect to the nodes'
# initial indices.
#
# See for reference:
#  - https://en.cppreference.com/w/cpp/algorithm/nth_element.
#  - https://github.com/scikit-learn/scikit-learn/pull/11103
#  - https://github.com/scikit-learn/scikit-learn/pull/19473

# Verbatim C++ code emitted into the generated module (Cython copies the
# docstring of a ``cdef extern from *`` block as-is), followed by the Cython
# declaration that exposes ``partition_node_indices_inner`` to this module.
#
# IndexComparator<D, I> orders two point indices ``a`` and ``b`` by the value
# ``data[idx * n_features + split_dim]``, i.e. by column ``split_dim`` of a
# row-major array with ``n_features`` columns.  Equal values are ordered by the
# indices themselves, which (absent NaN) makes the induced order total.
# NOTE(review): if a compared value is NaN, both ``==`` and ``<`` are false, so
# the comparator is no longer a strict weak ordering and std::nth_element's
# behaviour is undefined; presumably callers only pass finite data — confirm.
#
# partition_node_indices_inner calls std::nth_element on the range
# [node_indices, node_indices + n_points) with the nth position at
# node_indices + split_index, so split_index must lie in [0, n_points]
# (not checked here).
cdef extern from *:
    """
    #include <algorithm>

    template<class D, class I>
    class IndexComparator {
    private:
        const D *data;
        I split_dim, n_features;
    public:
        IndexComparator(const D *data, const I &split_dim, const I &n_features):
            data(data), split_dim(split_dim), n_features(n_features) {}

        bool operator()(const I &a, const I &b) const {
            D a_value = data[a * n_features + split_dim];
            D b_value = data[b * n_features + split_dim];
            return a_value == b_value ? a < b : a_value < b_value;
        }
    };

    template<class D, class I>
    void partition_node_indices_inner(
        const D *data,
        I *node_indices,
        const I &split_dim,
        const I &split_index,
        const I &n_features,
        const I &n_points) {
        IndexComparator<D, I> index_comparator(data, split_dim, n_features);
        std::nth_element(
            node_indices,
            node_indices + split_index,
            node_indices + n_points,
            index_comparator);
    }
    """
    # The Cython-side signature drops ``const`` and passes the integers by
    # value; the C++ template takes ``const D *`` and ``const I &``, which
    # bind to these arguments.  ``except +`` translates any C++ exception
    # raised by the call into a Python exception.
    void partition_node_indices_inner[D, I](
                D *data,
                I *node_indices,
                I split_dim,
                I split_index,
                I n_features,
                I n_points) except +


cdef int partition_node_indices(
        DTYPE_t *data,
        ITYPE_t *node_indices,
        ITYPE_t split_dim,
        ITYPE_t split_index,
        ITYPE_t n_features,
        ITYPE_t n_points) except -1:
    """Partition the points of a node around ``split_index`` along ``split_dim``.

    Upon return, the values in node_indices will be rearranged such that
    (assuming numpy-style indexing):

        data[node_indices[0:split_index], split_dim]
          <= data[node_indices[split_index], split_dim]

    and

        data[node_indices[split_index], split_dim]
          <= data[node_indices[split_index:n_points], split_dim]

    i.e. ``node_indices[split_index]`` holds the index that a full sort would
    place at that position.  The first group has ``split_index`` points and
    the second group has ``n_points - split_index`` points.  Passing
    ``split_index = n_points // 2`` therefore yields two groups whose sizes
    differ by at most one.

    Points with equal values along ``split_dim`` are ordered by their index.
    Assuming ``data`` contains no NaN, which points end up on each side of
    ``split_index`` therefore does not depend on their initial order in
    ``node_indices``.

    The work is delegated to ``std::nth_element`` (through
    ``partition_node_indices_inner``).  It is essentially a partial in-place
    quicksort around a set pivot and runs in linear time on average.

    Parameters
    ----------
    data : DTYPE_t pointer
        Pointer to a 2D, row-major (C-contiguous) array of the training data,
        of shape [N, n_features].  The value of point ``i`` along dimension
        ``j`` is read from ``data[i * n_features + j]``.
        N must be greater than any of the values in node_indices.
    node_indices : ITYPE_t pointer
        Pointer to a 1D array of length n_points.  This lists the indices of
        each of the points within the current node.  This will be modified
        in-place.
    split_dim : ITYPE_t
        the dimension on which to split; it must satisfy
        ``0 <= split_dim < n_features``.  This will usually be computed via
        the routine ``find_node_split_dim``.
    split_index : ITYPE_t
        the position within node_indices around which to split the points.
        It should satisfy ``0 <= split_index < n_points`` for the guarantees
        above to be meaningful.  Values outside ``[0, n_points]`` are
        undefined behaviour in ``std::nth_element`` and are not checked here.
    n_features : ITYPE_t
        the number of features (i.e. columns) in the 2D array pointed to by
        data.
    n_points : ITYPE_t
        the length of node_indices, i.e. the number of points in the current
        node.  In general this is smaller than the number of points in the
        original dataset.

    Returns
    -------
    status : int
        0 on success; on return, the contents of node_indices are modified
        as noted above.  If the underlying C++ call raises, the exception is
        translated into a Python exception and signalled to Cython callers
        by a return value of -1.
    """
    partition_node_indices_inner(
        data,
        node_indices,
        split_dim,
        split_index,
        n_features,
        n_points)
    return 0