1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
|
# BinaryTrees rely on partial sorts to partition their nodes during their
# initialisation.
#
# The C++ std library exposes nth_element, an efficient partial sort for this
# situation which has a linear time complexity as well as the best performances.
#
# To use std::algorithm::nth_element, a few fixture are defined using Cython:
# - partition_node_indices, a Cython function used in BinaryTrees, that calls
# - partition_node_indices_inner, a C++ function that wraps nth_element and uses
# - an IndexComparator to state how to compare KDTrees' indices
#
# IndexComparator has been defined so that partial sorts are stable with
# respect to the nodes initial indices.
#
# See for reference:
# - https://en.cppreference.com/w/cpp/algorithm/nth_element.
# - https://github.com/scikit-learn/scikit-learn/pull/11103
# - https://github.com/scikit-learn/scikit-learn/pull/19473
cdef extern from *:
"""
#include <algorithm>
template<class D, class I>
class IndexComparator {
private:
const D *data;
I split_dim, n_features;
public:
IndexComparator(const D *data, const I &split_dim, const I &n_features):
data(data), split_dim(split_dim), n_features(n_features) {}
bool operator()(const I &a, const I &b) const {
D a_value = data[a * n_features + split_dim];
D b_value = data[b * n_features + split_dim];
return a_value == b_value ? a < b : a_value < b_value;
}
};
template<class D, class I>
void partition_node_indices_inner(
const D *data,
I *node_indices,
const I &split_dim,
const I &split_index,
const I &n_features,
const I &n_points) {
IndexComparator<D, I> index_comparator(data, split_dim, n_features);
std::nth_element(
node_indices,
node_indices + split_index,
node_indices + n_points,
index_comparator);
}
"""
void partition_node_indices_inner[D, I](
D *data,
I *node_indices,
I split_dim,
I split_index,
I n_features,
I n_points) except +
cdef int partition_node_indices(
DTYPE_t *data,
ITYPE_t *node_indices,
ITYPE_t split_dim,
ITYPE_t split_index,
ITYPE_t n_features,
ITYPE_t n_points) except -1:
"""Partition points in the node into two equal-sized groups.
Upon return, the values in node_indices will be rearranged such that
(assuming numpy-style indexing):
data[node_indices[0:split_index], split_dim]
<= data[node_indices[split_index], split_dim]
and
data[node_indices[split_index], split_dim]
<= data[node_indices[split_index:n_points], split_dim]
The algorithm is essentially a partial in-place quicksort around a
set pivot.
Parameters
----------
data : double pointer
Pointer to a 2D array of the training data, of shape [N, n_features].
N must be greater than any of the values in node_indices.
node_indices : int pointer
Pointer to a 1D array of length n_points. This lists the indices of
each of the points within the current node. This will be modified
in-place.
split_dim : int
the dimension on which to split. This will usually be computed via
the routine ``find_node_split_dim``.
split_index : int
the index within node_indices around which to split the points.
n_features: int
the number of features (i.e columns) in the 2D array pointed by data.
n_points : int
the length of node_indices. This is also the number of points in
the original dataset.
Returns
-------
status : int
integer exit status. On return, the contents of node_indices are
modified as noted above.
"""
partition_node_indices_inner(
data,
node_indices,
split_dim,
split_index,
n_features,
n_points)
return 0
|