File: _argkmin_classmode.pyx.tp

package info (click to toggle)
scikit-learn 1.4.2%2Bdfsg-8
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 25,036 kB
  • sloc: python: 201,105; cpp: 5,790; ansic: 854; makefile: 304; sh: 56; javascript: 20
file content (182 lines) | stat: -rw-r--r-- 6,408 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
from cython cimport floating, integral
from cython.parallel cimport parallel, prange
from libcpp.map cimport map as cpp_map, pair as cpp_pair
from libc.stdlib cimport free

from ...utils._typedefs cimport intp_t, float64_t

import numpy as np
from scipy.sparse import issparse
from sklearn.utils.fixes import threadpool_limits
from ._classmode cimport WeightingStrategy

{{for name_suffix in ["32", "64"]}}
from ._argkmin cimport ArgKmin{{name_suffix}}
from ._datasets_pair cimport DatasetsPair{{name_suffix}}

cdef class ArgKminClassMode{{name_suffix}}(ArgKmin{{name_suffix}}):
    """
    {{name_suffix}}bit implementation of ArgKminClassMode.

    For each sample of X, accumulates a (possibly distance-weighted) histogram
    of the labels of its k nearest neighbors in Y, then returns these
    histograms normalized row-wise (one row of class scores per sample of X).

    The entries of `Y_labels` are used directly as column indices into
    `class_scores`, which has `unique_Y_labels.shape[0]` columns: they must
    therefore be encoded as integers in `[0, unique_Y_labels.shape[0])`.
    """
    cdef:
        const intp_t[:] Y_labels
        const intp_t[:] unique_Y_labels
        # Row i accumulates the label scores of the i-th sample of X.
        float64_t[:, :] class_scores
        # NOTE(review): declared but never populated by this class; labels are
        # expected to be pre-encoded as column indices (see class docstring).
        cpp_map[intp_t, intp_t] labels_to_index
        WeightingStrategy weight_type

    @classmethod
    def compute(
        cls,
        X,
        Y,
        intp_t k,
        weights,
        Y_labels,
        unique_Y_labels,
        str metric="euclidean",
        chunk_size=None,
        dict metric_kwargs=None,
        str strategy=None,
    ):
        """Compute the argkmin reduction with Y_labels.

        This classmethod is responsible for introspecting the arguments
        values to dispatch to the most appropriate implementation of
        :class:`ArgKminClassMode{{name_suffix}}`.

        This allows decoupling the API entirely from the implementation details
        whilst maintaining RAII: all temporarily allocated datastructures necessary
        for the concrete implementation are therefore freed when this classmethod
        returns.

        No instance _must_ directly be created outside of this class method.

        Returns an array of shape (n_samples_X, len(unique_Y_labels)) whose
        rows are the normalized label scores of each sample of X.
        """
        # Use a generic implementation that handles most scipy
        # metrics by computing the distances between 2 vectors at a time.
        pda = ArgKminClassMode{{name_suffix}}(
            datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric, metric_kwargs),
            k=k,
            chunk_size=chunk_size,
            strategy=strategy,
            weights=weights,
            Y_labels=Y_labels,
            unique_Y_labels=unique_Y_labels,
        )

        # Limit the number of threads in second level of nested parallelism for BLAS
        # to avoid threads over-subscription (in GEMM for instance).
        with threadpool_limits(limits=1, user_api="blas"):
            if pda.execute_in_parallel_on_Y:
                pda._parallel_on_Y()
            else:
                pda._parallel_on_X()

        return pda._finalize_results()

    def __init__(
        self,
        DatasetsPair{{name_suffix}} datasets_pair,
        const intp_t[:] Y_labels,
        const intp_t[:] unique_Y_labels,
        chunk_size=None,
        strategy=None,
        intp_t k=1,
        weights=None,
    ):
        super().__init__(
            datasets_pair=datasets_pair,
            chunk_size=chunk_size,
            strategy=strategy,
            k=k,
        )

        # `None` is treated as uniform weighting.
        if weights is None or weights == "uniform":
            self.weight_type = WeightingStrategy.uniform
        elif weights == "distance":
            self.weight_type = WeightingStrategy.distance
        else:
            # NOTE(review): callable weights are not implemented:
            # `weighted_histogram_mode` scores them uniformly.
            self.weight_type = WeightingStrategy.callable

        cdef:
            intp_t n_classes = unique_Y_labels.shape[0]
            intp_t label_pos, label

        # Labels are used as unchecked column indices into `class_scores`
        # inside nogil code: reject out-of-range values up front rather than
        # writing out of bounds later.
        for label_pos in range(Y_labels.shape[0]):
            label = Y_labels[label_pos]
            if label < 0 or label >= n_classes:
                raise ValueError(
                    f"Y_labels must be encoded as integers in [0, {n_classes}) "
                    f"(indices into unique_Y_labels); got {label}."
                )

        self.Y_labels = Y_labels
        self.unique_Y_labels = unique_Y_labels

        # Buffer used in building a histogram for one-pass weighted mode:
        # one row per sample of X, one column per class.
        self.class_scores = np.zeros(
            (self.n_samples_X, n_classes), dtype=np.float64,
        )

    def _finalize_results(self):
        # Normalize each row of class scores so that it sums to 1.
        # `np.asarray` shares memory with `class_scores`: the division is
        # performed in place.
        probabilities = np.asarray(self.class_scores)
        probabilities /= probabilities.sum(axis=1, keepdims=True)
        return probabilities

    cdef inline void weighted_histogram_mode(
        self,
        intp_t sample_index,
        intp_t* indices,
        float64_t* distances,
    ) noexcept nogil:
        # Add the scores of the `k` neighbors of sample `sample_index` to its
        # row of `class_scores`.
        #   indices:   `k` absolute neighbor indices in [0, n_samples_Y)
        #   distances: the `k` matching *reduced* distances (callers pass the
        #              `r_distances` heaps). Their order does not matter.
        cdef:
            intp_t neighbor_rank, neighbor_idx, neighbor_class_idx
            float64_t score_incr = 1
            # TODO: Implement other WeightingStrategy values
            bint use_distance_weighting = (
                self.weight_type == WeightingStrategy.distance
            )
            bint has_zero_distance = False

        if use_distance_weighting:
            # A zero distance would give an infinite weight, hence NaN scores
            # after normalization. When some neighbors coincide with the
            # query point, only those neighbors vote, each with weight 1.
            for neighbor_rank in range(self.k):
                if distances[neighbor_rank] <= 0:
                    has_zero_distance = True
                    break

        # Iterate through the sample k-nearest neighbours
        for neighbor_rank in range(self.k):
            # TODO: inspect if it worth permuting this condition
            # and the for-loop above for improved branching.
            if use_distance_weighting:
                if has_zero_distance:
                    score_incr = 1 if distances[neighbor_rank] <= 0 else 0
                else:
                    # Convert the reduced distance back to the actual metric
                    # distance before inverting it.
                    score_incr = 1 / self.datasets_pair.distance_metric._rdist_to_dist(
                        distances[neighbor_rank]
                    )
            # Absolute index of the neighbor_rank-th nearest neighbor
            # in range [0, n_samples_Y)
            neighbor_idx = indices[neighbor_rank]
            neighbor_class_idx = self.Y_labels[neighbor_idx]
            self.class_scores[sample_index][neighbor_class_idx] += score_incr
        return

    cdef void _parallel_on_X_prange_iter_finalize(
        self,
        intp_t thread_num,
        intp_t X_start,
        intp_t X_end,
    ) noexcept nogil:
        # Build the label histograms of the samples X[X_start:X_end] from
        # this thread's heaps. No sorting of the heaps is needed: the
        # histogram does not depend on the neighbors' order.
        cdef:
            intp_t idx, sample_index
        for idx in range(X_end - X_start):
            # One-pass top-one weighted mode
            # Compute the absolute index in [0, n_samples_X)
            sample_index = X_start + idx
            self.weighted_histogram_mode(
                sample_index,
                &self.heaps_indices_chunks[thread_num][idx * self.k],
                &self.heaps_r_distances_chunks[thread_num][idx * self.k],
            )
        return

    cdef void _parallel_on_Y_finalize(
        self,
    ) noexcept nogil:
        # Free the per-thread heaps, then build every sample's label
        # histogram from the merged `argkmin_indices` / `argkmin_distances`.
        cdef:
            intp_t sample_index, thread_num

        with nogil, parallel(num_threads=self.chunks_n_threads):
            # Deallocating temporary datastructures
            for thread_num in prange(self.chunks_n_threads, schedule='static'):
                free(self.heaps_r_distances_chunks[thread_num])
                free(self.heaps_indices_chunks[thread_num])

            # Each sample writes only its own row of `class_scores`, so the
            # rows can be processed in parallel without locks.
            for sample_index in prange(self.n_samples_X, schedule='static'):
                self.weighted_histogram_mode(
                    sample_index,
                    &self.argkmin_indices[sample_index][0],
                    &self.argkmin_distances[sample_index][0],
                )
        return

{{endfor}}