File: _argkmin.pyx.tp

package info (click to toggle)
scikit-learn 1.4.2+dfsg-8
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 25,036 kB
  • sloc: python: 201,105; cpp: 5,790; ansic: 854; makefile: 304; sh: 56; javascript: 20
file content (505 lines) | stat: -rw-r--r-- 19,252 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
from libc.stdlib cimport free, malloc
from libc.float cimport DBL_MAX
from cython cimport final
from cython.parallel cimport parallel, prange

from ...utils._heap cimport heap_push
from ...utils._sorting cimport simultaneous_sort
from ...utils._typedefs cimport intp_t, float64_t

import numpy as np
import warnings

from numbers import Integral
from scipy.sparse import issparse
from ...utils import check_array, check_scalar, _in_unstable_openblas_configuration
from ...utils.fixes import threadpool_limits

{{for name_suffix in ['64', '32']}}

from ._base cimport (
    BaseDistancesReduction{{name_suffix}},
    _sqeuclidean_row_norms{{name_suffix}},
)

from ._datasets_pair cimport DatasetsPair{{name_suffix}}

from ._middle_term_computer cimport MiddleTermComputer{{name_suffix}}


cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
    """float{{name_suffix}} implementation of the ArgKmin.

    For each sample of X, the indices of its `k` closest samples of Y are
    tracked in one row of `argkmin_indices`, with the associated (surrogate)
    distances in the same row of `argkmin_distances`. Rows are updated with
    `heap_push` while chunks are processed and are sorted in ascending order
    of distance once all chunks have been processed.
    """

    @classmethod
    def compute(
        cls,
        X,
        Y,
        intp_t k,
        metric="euclidean",
        chunk_size=None,
        dict metric_kwargs=None,
        str strategy=None,
        bint return_distance=False,
    ):
        """Compute the argkmin reduction.

        This classmethod is responsible for introspecting the arguments
        values to dispatch to the most appropriate implementation of
        :class:`ArgKmin{{name_suffix}}`.

        This allows decoupling the API entirely from the implementation details
        whilst maintaining RAII: all temporarily allocated datastructures necessary
        for the concrete implementation are therefore freed when this classmethod
        returns.

        No instance should directly be created outside of this class method.

        Parameters
        ----------
        X, Y : array-like or sparse matrix
            For each sample of X, the closest samples of Y are searched.
        k : intp_t
            Number of closest samples of Y to retrieve per sample of X
            (validated to be >= 1 by the constructor).
        metric : str, default="euclidean"
            "euclidean" and "sqeuclidean" dispatch to the specialised
            :class:`EuclideanArgKmin{{name_suffix}}`; any other metric goes
            through `DatasetsPair{{name_suffix}}.get_for`.
        chunk_size : int, default=None
            Forwarded to the constructor of the chosen implementation.
        metric_kwargs : dict, default=None
            Keyword arguments for the metric.
        strategy : str, default=None
            Forwarded to the constructor of the chosen implementation.
        return_distance : bool, default=False
            Whether to also return the distances.

        Returns
        -------
        argkmin_indices : ndarray of shape (n_samples_X, k)
            Indices of the closest samples of Y, row-wise sorted by
            ascending distance.
        (argkmin_distances, argkmin_indices) : tuple of ndarrays
            Returned instead when `return_distance=True`. For "sqeuclidean",
            distances are left squared.
        """
        if metric in ("euclidean", "sqeuclidean"):
            # Specialized implementation of ArgKmin for the Euclidean distance
            # for the dense-dense and sparse-sparse cases.
            # This implementation computes the distances by chunk using
            # a decomposition of the Squared Euclidean distance.
            # This specialisation has an improved arithmetic intensity for both
            # the dense and sparse settings, allowing in most case speed-ups of
            # several orders of magnitude compared to the generic ArgKmin
            # implementation.
            # For more information see MiddleTermComputer.
            use_squared_distances = metric == "sqeuclidean"
            pda = EuclideanArgKmin{{name_suffix}}(
                X=X, Y=Y, k=k,
                use_squared_distances=use_squared_distances,
                chunk_size=chunk_size,
                strategy=strategy,
                metric_kwargs=metric_kwargs,
            )
        else:
            # Fall back on a generic implementation that handles most scipy
            # metrics by computing the distances between 2 vectors at a time.
            pda = ArgKmin{{name_suffix}}(
                datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric, metric_kwargs),
                k=k,
                chunk_size=chunk_size,
                strategy=strategy,
            )

        # Limit the number of threads in second level of nested parallelism for BLAS
        # to avoid threads over-subscription (in GEMM for instance).
        with threadpool_limits(limits=1, user_api="blas"):
            if pda.execute_in_parallel_on_Y:
                pda._parallel_on_Y()
            else:
                pda._parallel_on_X()

        return pda._finalize_results(return_distance)

    def __init__(
        self,
        DatasetsPair{{name_suffix}} datasets_pair,
        chunk_size=None,
        strategy=None,
        intp_t k=1,
    ):
        super().__init__(
            datasets_pair=datasets_pair,
            chunk_size=chunk_size,
            strategy=strategy,
        )
        self.k = check_scalar(k, "k", Integral, min_val=1)

        # Allocating pointers to datastructures but not the datastructures themselves.
        # There are as many pointers as effective threads.
        #
        # For the sake of explicitness:
        #   - when parallelizing on X, the pointers of those heaps are referencing
        #   (with proper offsets) addresses of the two main heaps (see below)
        #   - when parallelizing on Y, the pointers of those heaps are referencing
        #   small heaps which are thread-wise-allocated and whose content will be
        #   merged with the main heaps'.
        self.heaps_r_distances_chunks = <float64_t **> malloc(
            sizeof(float64_t *) * self.chunks_n_threads
        )
        self.heaps_indices_chunks = <intp_t **> malloc(
            sizeof(intp_t *) * self.chunks_n_threads
        )
        # Those arrays of pointers are dereferenced in `noexcept nogil` methods
        # which cannot report an allocation failure: fail early here instead.
        # Whichever allocation did succeed is released by `__dealloc__`, which
        # skips NULL pointers. (malloc(0) may legitimately return NULL, hence
        # the guard on the number of threads.)
        if self.chunks_n_threads > 0 and (
            self.heaps_r_distances_chunks is NULL
            or self.heaps_indices_chunks is NULL
        ):
            raise MemoryError()

        # Main heaps which will be returned as results by `ArgKmin{{name_suffix}}.compute`.
        self.argkmin_indices = np.full((self.n_samples_X, self.k), 0, dtype=np.intp)
        self.argkmin_distances = np.full((self.n_samples_X, self.k), DBL_MAX, dtype=np.float64)

    def __dealloc__(self):
        # Only the arrays of per-thread pointers are owned here; the heaps they
        # point to are either views on the main heaps or freed in
        # `_parallel_on_Y_finalize`.
        if self.heaps_indices_chunks is not NULL:
            free(self.heaps_indices_chunks)

        if self.heaps_r_distances_chunks is not NULL:
            free(self.heaps_r_distances_chunks)

    cdef void _compute_and_reduce_distances_on_chunks(
        self,
        intp_t X_start,
        intp_t X_end,
        intp_t Y_start,
        intp_t Y_end,
        intp_t thread_num,
    ) noexcept nogil:
        cdef:
            intp_t i, j
            intp_t n_samples_X = X_end - X_start
            intp_t n_samples_Y = Y_end - Y_start
            float64_t *heaps_r_distances = self.heaps_r_distances_chunks[thread_num]
            intp_t *heaps_indices = self.heaps_indices_chunks[thread_num]

        # Pushing the distances and their associated indices on a heap
        # which by construction will keep track of the argkmin.
        for i in range(n_samples_X):
            for j in range(n_samples_Y):
                heap_push(
                    values=heaps_r_distances + i * self.k,
                    indices=heaps_indices + i * self.k,
                    size=self.k,
                    val=self.datasets_pair.surrogate_dist(X_start + i, Y_start + j),
                    val_idx=Y_start + j,
                )

    cdef void _parallel_on_X_init_chunk(
        self,
        intp_t thread_num,
        intp_t X_start,
        intp_t X_end,
    ) noexcept nogil:
        # As this strategy is embarrassingly parallel, we can set each
        # thread's heaps pointer to the proper position on the main heaps.
        self.heaps_r_distances_chunks[thread_num] = &self.argkmin_distances[X_start, 0]
        self.heaps_indices_chunks[thread_num] = &self.argkmin_indices[X_start, 0]

    cdef void _parallel_on_X_prange_iter_finalize(
        self,
        intp_t thread_num,
        intp_t X_start,
        intp_t X_end,
    ) noexcept nogil:
        cdef:
            intp_t idx

        # Sorting the main heaps portion associated to `X[X_start:X_end]`
        # in ascending order w.r.t the distances.
        for idx in range(X_end - X_start):
            simultaneous_sort(
                self.heaps_r_distances_chunks[thread_num] + idx * self.k,
                self.heaps_indices_chunks[thread_num] + idx * self.k,
                self.k
            )

    cdef void _parallel_on_Y_init(
        self,
    ) noexcept nogil:
        cdef:
            # Maximum number of scalar elements (the last chunks can be smaller)
            intp_t heaps_size = self.X_n_samples_chunk * self.k
            intp_t thread_num

        # The allocation is done in parallel for data locality purposes: this way
        # the heaps used in each threads are allocated in pages which are closer
        # to the CPU core used by the thread.
        # See comments about First Touch Placement Policy:
        # https://www.openmp.org/wp-content/uploads/openmp-webinar-vanderPas-20210318.pdf #noqa
        #
        # NOTE(review): those malloc results are not checked for NULL; this
        # method is `noexcept nogil` so it cannot raise MemoryError.
        for thread_num in prange(self.chunks_n_threads, schedule='static', nogil=True,
                                 num_threads=self.chunks_n_threads):
            # As chunks of X are shared across threads, so must their
            # heaps. To solve this, each thread has its own heaps
            # which are then synchronised back in the main ones.
            self.heaps_r_distances_chunks[thread_num] = <float64_t *> malloc(
                heaps_size * sizeof(float64_t)
            )
            self.heaps_indices_chunks[thread_num] = <intp_t *> malloc(
                heaps_size * sizeof(intp_t)
            )

    cdef void _parallel_on_Y_parallel_init(
        self,
        intp_t thread_num,
        intp_t X_start,
        intp_t X_end,
    ) noexcept nogil:
        cdef:
            intp_t idx
            float64_t *heaps_r_distances = self.heaps_r_distances_chunks[thread_num]
            intp_t *heaps_indices = self.heaps_indices_chunks[thread_num]

        # Initialising heaps (memset can't be used here)
        for idx in range(self.X_n_samples_chunk * self.k):
            heaps_r_distances[idx] = DBL_MAX
            heaps_indices[idx] = -1

    @final
    cdef void _parallel_on_Y_synchronize(
        self,
        intp_t X_start,
        intp_t X_end,
    ) noexcept nogil:
        cdef:
            intp_t idx, jdx, thread_num
        with nogil, parallel(num_threads=self.effective_n_threads):
            # Synchronising the thread heaps with the main heaps.
            # This is done in parallel sample-wise (no need for locks).
            #
            # This might break each thread's data locality as each heap which
            # was allocated in a thread is now being used in several threads.
            #
            # Still, this parallel pattern has shown to be efficient in practice.
            for idx in prange(X_end - X_start, schedule="static"):
                for thread_num in range(self.chunks_n_threads):
                    for jdx in range(self.k):
                        heap_push(
                            values=&self.argkmin_distances[X_start + idx, 0],
                            indices=&self.argkmin_indices[X_start + idx, 0],
                            size=self.k,
                            val=self.heaps_r_distances_chunks[thread_num][idx * self.k + jdx],
                            val_idx=self.heaps_indices_chunks[thread_num][idx * self.k + jdx],
                        )

    cdef void _parallel_on_Y_finalize(
        self,
    ) noexcept nogil:
        cdef:
            intp_t idx, thread_num

        with nogil, parallel(num_threads=self.chunks_n_threads):
            # Deallocating temporary datastructures
            for thread_num in prange(self.chunks_n_threads, schedule='static'):
                free(self.heaps_r_distances_chunks[thread_num])
                free(self.heaps_indices_chunks[thread_num])

            # Sorting the main heaps in ascending order w.r.t the distances.
            # This is done in parallel sample-wise (no need for locks).
            for idx in prange(self.n_samples_X, schedule='static'):
                simultaneous_sort(
                    &self.argkmin_distances[idx, 0],
                    &self.argkmin_indices[idx, 0],
                    self.k,
                )

    cdef void compute_exact_distances(self) noexcept nogil:
        # Converts, in place, the surrogate distances used for the reduction
        # into the metric's actual distances.
        cdef:
            intp_t i, j
            float64_t[:, ::1] distances = self.argkmin_distances
        for i in prange(self.n_samples_X, schedule='static', nogil=True,
                        num_threads=self.effective_n_threads):
            for j in range(self.k):
                distances[i, j] = self.datasets_pair.distance_metric._rdist_to_dist(
                    # Guard against potential -0., causing nan production.
                    max(distances[i, j], 0.)
                )

    def _finalize_results(self, bint return_distance=False):
        if return_distance:
            # We need to recompute distances because we relied on
            # surrogate distances for the reduction.
            self.compute_exact_distances()

            # Values are returned identically to the way `KNeighborsMixin.kneighbors`
            # returns values. This is counter-intuitive but this allows not using
            # complex adaptations where `ArgKmin.compute` is called.
            return np.asarray(self.argkmin_distances), np.asarray(self.argkmin_indices)

        return np.asarray(self.argkmin_indices)


cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
    """EuclideanDistance-specialisation of ArgKmin{{name_suffix}}.

    Squared Euclidean distances are computed chunk-wise as
    `X_norm_squared[i] + middle_term[i, j] + Y_norm_squared[j]`, where the
    middle terms are provided by `MiddleTermComputer{{name_suffix}}`.
    """

    @classmethod
    def is_usable_for(cls, X, Y, metric) -> bool:
        return (ArgKmin{{name_suffix}}.is_usable_for(X, Y, metric) and
                not _in_unstable_openblas_configuration())

    def __init__(
        self,
        X,
        Y,
        intp_t k,
        bint use_squared_distances=False,
        chunk_size=None,
        strategy=None,
        metric_kwargs=None,
    ):
        """Set up the Euclidean-specialised reduction.

        Only the "X_norm_squared" and "Y_norm_squared" entries of
        `metric_kwargs` are used (as precomputed squared row norms of X and
        Y); any other entry triggers a UserWarning and is ignored. The
        caller's `metric_kwargs` dict is not modified.

        Raises
        ------
        ValueError
            If a provided squared-norms array is not 1-dimensional with one
            entry per sample of the corresponding dataset.
        """
        if (
            isinstance(metric_kwargs, dict) and
            (metric_kwargs.keys() - {"X_norm_squared", "Y_norm_squared"})
        ):
            warnings.warn(
                f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't "
                f"usable for this case (EuclideanArgKmin{{name_suffix}}) and will be ignored.",
                UserWarning,
                stacklevel=3,
            )

        super().__init__(
            # The datasets pair here is used for exact distances computations
            datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric="euclidean"),
            chunk_size=chunk_size,
            strategy=strategy,
            k=k,
        )
        cdef:
            intp_t dist_middle_terms_chunks_size = self.Y_n_samples_chunk * self.X_n_samples_chunk

        self.middle_term_computer = MiddleTermComputer{{name_suffix}}.get_for(
            X,
            Y,
            self.effective_n_threads,
            self.chunks_n_threads,
            dist_middle_terms_chunks_size,
            n_features=X.shape[1],
            chunk_size=self.chunk_size,
        )

        # Norms are read (not popped) from `metric_kwargs` so that the caller's
        # dict is left untouched and can be reused across calls.
        # They are indexed by sample position in
        # `_compute_and_reduce_distances_on_chunks`: a length mismatch would
        # read out of range, hence the shape checks below.
        if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs:
            Y_norm_squared = check_array(
                metric_kwargs["Y_norm_squared"],
                ensure_2d=False,
                input_name="Y_norm_squared",
                dtype=np.float64,
            )
            if Y_norm_squared.ndim != 1 or Y_norm_squared.shape[0] != Y.shape[0]:
                raise ValueError(
                    f"Y_norm_squared should be a 1-dimensional array of length "
                    f"{Y.shape[0]}, got an array of shape {Y_norm_squared.shape}."
                )
            self.Y_norm_squared = Y_norm_squared
        else:
            self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}(
                Y,
                self.effective_n_threads,
            )

        if metric_kwargs is not None and "X_norm_squared" in metric_kwargs:
            X_norm_squared = check_array(
                metric_kwargs["X_norm_squared"],
                ensure_2d=False,
                input_name="X_norm_squared",
                dtype=np.float64,
            )
            if X_norm_squared.ndim != 1 or X_norm_squared.shape[0] != X.shape[0]:
                raise ValueError(
                    f"X_norm_squared should be a 1-dimensional array of length "
                    f"{X.shape[0]}, got an array of shape {X_norm_squared.shape}."
                )
            self.X_norm_squared = X_norm_squared
        else:
            # Do not recompute norms if datasets are identical.
            self.X_norm_squared = (
                self.Y_norm_squared if X is Y else
                _sqeuclidean_row_norms{{name_suffix}}(
                    X,
                    self.effective_n_threads,
                )
            )

        self.use_squared_distances = use_squared_distances

    @final
    cdef void compute_exact_distances(self) noexcept nogil:
        # Squared distances are already the requested output for "sqeuclidean".
        if not self.use_squared_distances:
            ArgKmin{{name_suffix}}.compute_exact_distances(self)

    @final
    cdef void _parallel_on_X_parallel_init(
        self,
        intp_t thread_num,
    ) noexcept nogil:
        ArgKmin{{name_suffix}}._parallel_on_X_parallel_init(self, thread_num)
        self.middle_term_computer._parallel_on_X_parallel_init(thread_num)

    @final
    cdef void _parallel_on_X_init_chunk(
        self,
        intp_t thread_num,
        intp_t X_start,
        intp_t X_end,
    ) noexcept nogil:
        ArgKmin{{name_suffix}}._parallel_on_X_init_chunk(self, thread_num, X_start, X_end)
        self.middle_term_computer._parallel_on_X_init_chunk(thread_num, X_start, X_end)

    @final
    cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
        self,
        intp_t X_start,
        intp_t X_end,
        intp_t Y_start,
        intp_t Y_end,
        intp_t thread_num,
    ) noexcept nogil:
        ArgKmin{{name_suffix}}._parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
            self,
            X_start, X_end,
            Y_start, Y_end,
            thread_num,
        )
        self.middle_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
            X_start, X_end, Y_start, Y_end, thread_num,
        )

    @final
    cdef void _parallel_on_Y_init(
        self,
    ) noexcept nogil:
        ArgKmin{{name_suffix}}._parallel_on_Y_init(self)
        self.middle_term_computer._parallel_on_Y_init()

    @final
    cdef void _parallel_on_Y_parallel_init(
        self,
        intp_t thread_num,
        intp_t X_start,
        intp_t X_end,
    ) noexcept nogil:
        ArgKmin{{name_suffix}}._parallel_on_Y_parallel_init(self, thread_num, X_start, X_end)
        self.middle_term_computer._parallel_on_Y_parallel_init(thread_num, X_start, X_end)

    @final
    cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
        self,
        intp_t X_start,
        intp_t X_end,
        intp_t Y_start,
        intp_t Y_end,
        intp_t thread_num,
    ) noexcept nogil:
        ArgKmin{{name_suffix}}._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
            self,
            X_start, X_end,
            Y_start, Y_end,
            thread_num,
        )
        self.middle_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
            X_start, X_end, Y_start, Y_end, thread_num
        )

    @final
    cdef void _compute_and_reduce_distances_on_chunks(
        self,
        intp_t X_start,
        intp_t X_end,
        intp_t Y_start,
        intp_t Y_end,
        intp_t thread_num,
    ) noexcept nogil:
        cdef:
            intp_t i, j
            float64_t sqeuclidean_dist_i_j
            intp_t n_X = X_end - X_start
            intp_t n_Y = Y_end - Y_start
            float64_t * dist_middle_terms = self.middle_term_computer._compute_dist_middle_terms(
                X_start, X_end, Y_start, Y_end, thread_num
            )
            float64_t * heaps_r_distances = self.heaps_r_distances_chunks[thread_num]
            intp_t * heaps_indices = self.heaps_indices_chunks[thread_num]

        # Pushing the distances and their associated indices on heaps
        # which keep track of the argkmin.
        for i in range(n_X):
            for j in range(n_Y):
                sqeuclidean_dist_i_j = (
                    self.X_norm_squared[i + X_start] +
                    dist_middle_terms[i * n_Y + j] +
                    self.Y_norm_squared[j + Y_start]
                )

                # Catastrophic cancellation might cause -0. to be present,
                # e.g. when computing d(x_i, y_i) when X is Y.
                sqeuclidean_dist_i_j = max(0., sqeuclidean_dist_i_j)

                heap_push(
                    values=heaps_r_distances + i * self.k,
                    indices=heaps_indices + i * self.k,
                    size=self.k,
                    val=sqeuclidean_dist_i_j,
                    val_idx=j + Y_start,
                )

{{endfor}}