File: _k_means_elkan.pyx

package info (click to toggle)
scikit-learn 1.4.2%2Bdfsg-8
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 25,036 kB
  • sloc: python: 201,105; cpp: 5,790; ansic: 854; makefile: 304; sh: 56; javascript: 20
file content (687 lines) | stat: -rw-r--r-- 28,135 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
# Author: Andreas Mueller
#
# Licence: BSD 3 clause

from cython cimport floating
from cython.parallel import prange, parallel
from libc.stdlib cimport calloc, free
from libc.string cimport memset

from ..utils._openmp_helpers cimport omp_lock_t
from ..utils._openmp_helpers cimport omp_init_lock
from ..utils._openmp_helpers cimport omp_destroy_lock
from ..utils._openmp_helpers cimport omp_set_lock
from ..utils._openmp_helpers cimport omp_unset_lock
from ..utils.extmath import row_norms
from ._k_means_common import CHUNK_SIZE
from ._k_means_common cimport _relocate_empty_clusters_dense
from ._k_means_common cimport _relocate_empty_clusters_sparse
from ._k_means_common cimport _euclidean_dense_dense
from ._k_means_common cimport _euclidean_sparse_dense
from ._k_means_common cimport _average_centers
from ._k_means_common cimport _center_shift


def init_bounds_dense(
        const floating[:, ::1] X,                      # IN
        const floating[:, ::1] centers,                # IN
        const floating[:, ::1] center_half_distances,  # IN
        int[::1] labels,                               # OUT
        floating[::1] upper_bounds,                    # OUT
        floating[:, ::1] lower_bounds,                 # OUT
        int n_threads):
    """Initialize upper and lower bounds for each sample for dense input data.

    Given X, centers and the pairwise distances divided by 2.0 between the
    centers this calculates the upper bounds and lower bounds for each sample.
    The upper bound for each sample is set to the distance between the sample
    and the closest center.

    The lower bound for each sample is a one-dimensional array of n_clusters.
    For each sample i assume that the previously assigned cluster is c1 and the
    previous closest distance is dist, for a new cluster c2, the
    lower_bound[i][c2] is set to distance between the sample and this new
    cluster, if and only if dist > center_half_distances[c1][c2]. This prevents
    computation of unnecessary distances for each sample to the clusters that
    it is unlikely to be assigned to.

    Parameters
    ----------
    X : ndarray of shape (n_samples, n_features), dtype=floating
        The input data.

    centers : ndarray of shape (n_clusters, n_features), dtype=floating
        The cluster centers.

    center_half_distances : ndarray of shape (n_clusters, n_clusters), \
            dtype=floating
        The half of the distance between any 2 clusters centers.

    labels : ndarray of shape(n_samples), dtype=int
        The label for each sample. This array is modified in place.

    upper_bounds : ndarray of shape(n_samples,), dtype=floating
        The upper bound on the distance between each sample and its closest
        cluster center. This array is modified in place.

    lower_bounds : ndarray, of shape(n_samples, n_clusters), dtype=floating
        The lower bound on the distance between each sample and each cluster
        center. This array is modified in place.

    n_threads : int
        The number of threads to be used by openmp.
    """
    cdef:
        int n_samples = X.shape[0]
        int n_clusters = centers.shape[0]
        int n_features = X.shape[1]

        floating min_dist, dist
        int best_cluster, i, j

    for i in prange(
        n_samples, num_threads=n_threads, schedule='static', nogil=True
    ):
        best_cluster = 0
        min_dist = _euclidean_dense_dense(&X[i, 0], &centers[0, 0],
                                          n_features, False)
        lower_bounds[i, 0] = min_dist
        for j in range(1, n_clusters):
            if min_dist > center_half_distances[best_cluster, j]:
                dist = _euclidean_dense_dense(&X[i, 0], &centers[j, 0],
                                              n_features, False)
                lower_bounds[i, j] = dist
                if dist < min_dist:
                    min_dist = dist
                    best_cluster = j
        labels[i] = best_cluster
        upper_bounds[i] = min_dist


def init_bounds_sparse(
        X,                                             # IN
        const floating[:, ::1] centers,                # IN
        const floating[:, ::1] center_half_distances,  # IN
        int[::1] labels,                               # OUT
        floating[::1] upper_bounds,                    # OUT
        floating[:, ::1] lower_bounds,                 # OUT
        int n_threads):
    """Initialize upper and lower bounds for each sample for sparse input data.

    Given X, centers and the pairwise distances divided by 2.0 between the
    centers this calculates the upper bounds and lower bounds for each sample.
    The upper bound for each sample is set to the distance between the sample
    and the closest center.

    The lower bound for each sample is a one-dimensional array of n_clusters.
    For each sample i assume that the previously assigned cluster is c1 and the
    previous closest distance is dist, for a new cluster c2, the
    lower_bound[i][c2] is set to distance between the sample and this new
    cluster, if and only if dist > center_half_distances[c1][c2]. This prevents
    computation of unnecessary distances for each sample to the clusters that
    it is unlikely to be assigned to.

    Parameters
    ----------
    X : sparse matrix of shape (n_samples, n_features), dtype=floating
        The input data. Must be in CSR format.

    centers : ndarray of shape (n_clusters, n_features), dtype=floating
        The cluster centers.

    center_half_distances : ndarray of shape (n_clusters, n_clusters), \
            dtype=floating
        The half of the distance between any 2 clusters centers.

    labels : ndarray of shape(n_samples), dtype=int
        The label for each sample. This array is modified in place.

    upper_bounds : ndarray of shape(n_samples,), dtype=floating
        The upper bound on the distance between each sample and its closest
        cluster center. This array is modified in place.

    lower_bounds : ndarray of shape(n_samples, n_clusters), dtype=floating
        The lower bound on the distance between each sample and each cluster
        center. This array is modified in place.

    n_threads : int
        The number of threads to be used by openmp.
    """
    cdef:
        int n_samples = X.shape[0]
        int n_clusters = centers.shape[0]

        floating[::1] X_data = X.data
        int[::1] X_indices = X.indices
        int[::1] X_indptr = X.indptr

        floating min_dist, dist
        int best_cluster, i, j

        floating[::1] centers_squared_norms = row_norms(centers, squared=True)

    for i in prange(
        n_samples, num_threads=n_threads, schedule='static', nogil=True
    ):
        best_cluster = 0
        min_dist = _euclidean_sparse_dense(
            X_data[X_indptr[i]: X_indptr[i + 1]],
            X_indices[X_indptr[i]: X_indptr[i + 1]],
            centers[0], centers_squared_norms[0], False)

        lower_bounds[i, 0] = min_dist
        for j in range(1, n_clusters):
            if min_dist > center_half_distances[best_cluster, j]:
                dist = _euclidean_sparse_dense(
                    X_data[X_indptr[i]: X_indptr[i + 1]],
                    X_indices[X_indptr[i]: X_indptr[i + 1]],
                    centers[j], centers_squared_norms[j], False)
                lower_bounds[i, j] = dist
                if dist < min_dist:
                    min_dist = dist
                    best_cluster = j
        labels[i] = best_cluster
        upper_bounds[i] = min_dist


def elkan_iter_chunked_dense(
        const floating[:, ::1] X,                      # IN
        const floating[::1] sample_weight,             # IN
        const floating[:, ::1] centers_old,            # IN
        floating[:, ::1] centers_new,                  # OUT
        floating[::1] weight_in_clusters,              # OUT
        const floating[:, ::1] center_half_distances,  # IN
        const floating[::1] distance_next_center,      # IN
        floating[::1] upper_bounds,                    # INOUT
        floating[:, ::1] lower_bounds,                 # INOUT
        int[::1] labels,                               # INOUT
        floating[::1] center_shift,                    # OUT
        int n_threads,
        bint update_centers=True):
    """Single iteration of K-means Elkan algorithm with dense input.

    Update labels and centers (inplace), for one iteration, distributed
    over data chunks.

    Parameters
    ----------
    X : ndarray of shape (n_samples, n_features), dtype=floating
        The observations to cluster.

    sample_weight : ndarray of shape (n_samples,), dtype=floating
        The weights for each observation in X.

    centers_old : ndarray of shape (n_clusters, n_features), dtype=floating
        Centers before previous iteration, placeholder for the centers after
        previous iteration.

    centers_new : ndarray of shape (n_clusters, n_features), dtype=floating
        Centers after previous iteration, placeholder for the new centers
        computed during this iteration.

    weight_in_clusters : ndarray of shape (n_clusters,), dtype=floating
        Placeholder for the sums of the weights of every observation assigned
        to each center.

    center_half_distances : ndarray of shape (n_clusters, n_clusters), \
            dtype=floating
        Half pairwise distances between centers.

    distance_next_center : ndarray of shape (n_clusters,), dtype=floating
        Distance between each center its closest center.

    upper_bounds : ndarray of shape (n_samples,), dtype=floating
        Upper bound for the distance between each sample and its center,
        updated inplace.

    lower_bounds : ndarray of shape (n_samples, n_clusters), dtype=floating
        Lower bound for the distance between each sample and each center,
        updated inplace.

    labels : ndarray of shape (n_samples,), dtype=int
        labels assignment.

    center_shift : ndarray of shape (n_clusters,), dtype=floating
        Distance between old and new centers.

    n_threads : int
        The number of threads to be used by openmp.

    update_centers : bool
        - If True, the labels and the new centers will be computed, i.e. runs
          the E-step and the M-step of the algorithm.
        - If False, only the labels will be computed, i.e runs the E-step of
          the algorithm. This is useful especially when calling predict on a
          fitted model.
    """
    cdef:
        int n_samples = X.shape[0]
        int n_features = X.shape[1]
        int n_clusters = centers_new.shape[0]

    if n_samples == 0:
        # An empty array was passed, do nothing and return early (before
        # attempting to compute n_chunks). This can typically happen when
        # calling the prediction function of a bisecting k-means model with a
        # large fraction of outiers.
        return

    cdef:
        # hard-coded number of samples per chunk. Splitting in chunks is
        # necessary to get parallelism. Chunk size chosen to be same as lloyd's
        int n_samples_chunk = CHUNK_SIZE if n_samples > CHUNK_SIZE else n_samples
        int n_chunks = n_samples // n_samples_chunk
        int n_samples_rem = n_samples % n_samples_chunk
        int chunk_idx
        int start, end

        int i, j, k

        floating *centers_new_chunk
        floating *weight_in_clusters_chunk

        omp_lock_t lock

    # count remainder chunk in total number of chunks
    n_chunks += n_samples != n_chunks * n_samples_chunk

    # number of threads should not be bigger than number of chunks
    n_threads = min(n_threads, n_chunks)

    if update_centers:
        memset(&centers_new[0, 0], 0, n_clusters * n_features * sizeof(floating))
        memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating))
        omp_init_lock(&lock)

    with nogil, parallel(num_threads=n_threads):
        # thread local buffers
        centers_new_chunk = <floating*> calloc(n_clusters * n_features, sizeof(floating))
        weight_in_clusters_chunk = <floating*> calloc(n_clusters, sizeof(floating))

        for chunk_idx in prange(n_chunks, schedule='static'):
            start = chunk_idx * n_samples_chunk
            if chunk_idx == n_chunks - 1 and n_samples_rem > 0:
                end = start + n_samples_rem
            else:
                end = start + n_samples_chunk

            _update_chunk_dense(
                X[start: end],
                sample_weight[start: end],
                centers_old,
                center_half_distances,
                distance_next_center,
                labels[start: end],
                upper_bounds[start: end],
                lower_bounds[start: end],
                centers_new_chunk,
                weight_in_clusters_chunk,
                update_centers)

        # reduction from local buffers.
        if update_centers:
            # The lock is necessary to avoid race conditions when aggregating
            # info from different thread-local buffers.
            omp_set_lock(&lock)
            for j in range(n_clusters):
                weight_in_clusters[j] += weight_in_clusters_chunk[j]
                for k in range(n_features):
                    centers_new[j, k] += centers_new_chunk[j * n_features + k]
            omp_unset_lock(&lock)

        free(centers_new_chunk)
        free(weight_in_clusters_chunk)

    if update_centers:
        omp_destroy_lock(&lock)
        _relocate_empty_clusters_dense(X, sample_weight, centers_old,
                                       centers_new, weight_in_clusters, labels)

        _average_centers(centers_new, weight_in_clusters)
        _center_shift(centers_old, centers_new, center_shift)

        # update lower and upper bounds
        for i in range(n_samples):
            upper_bounds[i] += center_shift[labels[i]]

            for j in range(n_clusters):
                lower_bounds[i, j] -= center_shift[j]
                if lower_bounds[i, j] < 0:
                    lower_bounds[i, j] = 0


cdef void _update_chunk_dense(
        const floating[:, ::1] X,                      # IN
        const floating[::1] sample_weight,             # IN
        const floating[:, ::1] centers_old,            # IN
        const floating[:, ::1] center_half_distances,  # IN
        const floating[::1] distance_next_center,      # IN
        int[::1] labels,                               # INOUT
        floating[::1] upper_bounds,                    # INOUT
        floating[:, ::1] lower_bounds,                 # INOUT
        floating *centers_new,                         # OUT
        floating *weight_in_clusters,                  # OUT
        bint update_centers) noexcept nogil:
    """K-means combined EM step for one dense data chunk.

    Compute the partial contribution of a single data chunk to the labels and
    centers.
    """
    cdef:
        int n_samples = labels.shape[0]
        int n_clusters = centers_old.shape[0]
        int n_features = centers_old.shape[1]

        floating upper_bound, distance
        int i, j, k, label

    for i in range(n_samples):
        upper_bound = upper_bounds[i]
        bounds_tight = 0
        label = labels[i]

        # Next center is not far away from the currently assigned center.
        # Sample might need to be assigned to another center.
        if not distance_next_center[label] >= upper_bound:

            for j in range(n_clusters):

                # If this holds, then center_index is a good candidate for the
                # sample to be relabelled, and we need to confirm this by
                # recomputing the upper and lower bounds.
                if (
                    j != label
                    and (upper_bound > lower_bounds[i, j])
                    and (upper_bound > center_half_distances[label, j])
                ):

                    # Recompute upper bound by calculating the actual distance
                    # between the sample and its current assigned center.
                    if not bounds_tight:
                        upper_bound = _euclidean_dense_dense(
                            &X[i, 0], &centers_old[label, 0], n_features, False)
                        lower_bounds[i, label] = upper_bound
                        bounds_tight = 1

                    # If the condition still holds, then compute the actual
                    # distance between the sample and center. If this is less
                    # than the previous distance, reassign label.
                    if (
                        upper_bound > lower_bounds[i, j]
                        or (upper_bound > center_half_distances[label, j])
                    ):

                        distance = _euclidean_dense_dense(
                            &X[i, 0], &centers_old[j, 0], n_features, False)
                        lower_bounds[i, j] = distance
                        if distance < upper_bound:
                            label = j
                            upper_bound = distance

            labels[i] = label
            upper_bounds[i] = upper_bound

        if update_centers:
            weight_in_clusters[label] += sample_weight[i]
            for k in range(n_features):
                centers_new[label * n_features + k] += X[i, k] * sample_weight[i]


def elkan_iter_chunked_sparse(
        X,                                             # IN
        const floating[::1] sample_weight,             # IN
        const floating[:, ::1] centers_old,            # IN
        floating[:, ::1] centers_new,                  # OUT
        floating[::1] weight_in_clusters,              # OUT
        const floating[:, ::1] center_half_distances,  # IN
        const floating[::1] distance_next_center,      # IN
        floating[::1] upper_bounds,                    # INOUT
        floating[:, ::1] lower_bounds,                 # INOUT
        int[::1] labels,                               # INOUT
        floating[::1] center_shift,                    # OUT
        int n_threads,
        bint update_centers=True):
    """Single iteration of K-means Elkan algorithm with sparse input.

    Update labels and centers (inplace), for one iteration, distributed
    over data chunks.

    Parameters
    ----------
    X : sparse matrix of shape (n_samples, n_features)
        The observations to cluster. Must be in CSR format.

    sample_weight : ndarray of shape (n_samples,), dtype=floating
        The weights for each observation in X.

    centers_old : ndarray of shape (n_clusters, n_features), dtype=floating
        Centers before previous iteration, placeholder for the centers after
        previous iteration.

    centers_new : ndarray of shape (n_clusters, n_features), dtype=floating
        Centers after previous iteration, placeholder for the new centers
        computed during this iteration.

    weight_in_clusters : ndarray of shape (n_clusters,), dtype=floating
        Placeholder for the sums of the weights of every observation assigned
        to each center.

    center_half_distances : ndarray of shape (n_clusters, n_clusters), \
            dtype=floating
        Half pairwise distances between centers.

    distance_next_center : ndarray of shape (n_clusters,), dtype=floating
        Distance between each center its closest center.

    upper_bounds : ndarray of shape (n_samples,), dtype=floating
        Upper bound for the distance between each sample and its center,
        updated inplace.

    lower_bounds : ndarray of shape (n_samples, n_clusters), dtype=floating
        Lower bound for the distance between each sample and each center,
        updated inplace.

    labels : ndarray of shape (n_samples,), dtype=int
        labels assignment.

    center_shift : ndarray of shape (n_clusters,), dtype=floating
        Distance between old and new centers.

    n_threads : int
        The number of threads to be used by openmp.

    update_centers : bool
        - If True, the labels and the new centers will be computed, i.e. runs
          the E-step and the M-step of the algorithm.
        - If False, only the labels will be computed, i.e runs the E-step of
          the algorithm. This is useful especially when calling predict on a
          fitted model.
    """
    cdef:
        int n_samples = X.shape[0]
        int n_features = X.shape[1]
        int n_clusters = centers_new.shape[0]

    if n_samples == 0:
        # An empty array was passed, do nothing and return early (before
        # attempting to compute n_chunks). This can typically happen when
        # calling the prediction function of a bisecting k-means model with a
        # large fraction of outiers.
        return

    cdef:
        floating[::1] X_data = X.data
        int[::1] X_indices = X.indices
        int[::1] X_indptr = X.indptr

        # hard-coded number of samples per chunk. Splitting in chunks is
        # necessary to get parallelism. Chunk size chosen to be same as lloyd's
        int n_samples_chunk = CHUNK_SIZE if n_samples > CHUNK_SIZE else n_samples
        int n_chunks = n_samples // n_samples_chunk
        int n_samples_rem = n_samples % n_samples_chunk
        int chunk_idx
        int start, end

        int i, j, k

        floating[::1] centers_squared_norms = row_norms(centers_old, squared=True)

        floating *centers_new_chunk
        floating *weight_in_clusters_chunk

        omp_lock_t lock

    # count remainder chunk in total number of chunks
    n_chunks += n_samples != n_chunks * n_samples_chunk

    # number of threads should not be bigger than number of chunks
    n_threads = min(n_threads, n_chunks)

    if update_centers:
        memset(&centers_new[0, 0], 0, n_clusters * n_features * sizeof(floating))
        memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating))
        omp_init_lock(&lock)

    with nogil, parallel(num_threads=n_threads):
        # thread local buffers
        centers_new_chunk = <floating*> calloc(n_clusters * n_features, sizeof(floating))
        weight_in_clusters_chunk = <floating*> calloc(n_clusters, sizeof(floating))

        for chunk_idx in prange(n_chunks, schedule='static'):
            start = chunk_idx * n_samples_chunk
            if chunk_idx == n_chunks - 1 and n_samples_rem > 0:
                end = start + n_samples_rem
            else:
                end = start + n_samples_chunk

            _update_chunk_sparse(
                X_data[X_indptr[start]: X_indptr[end]],
                X_indices[X_indptr[start]: X_indptr[end]],
                X_indptr[start: end+1],
                sample_weight[start: end],
                centers_old,
                centers_squared_norms,
                center_half_distances,
                distance_next_center,
                labels[start: end],
                upper_bounds[start: end],
                lower_bounds[start: end],
                centers_new_chunk,
                weight_in_clusters_chunk,
                update_centers)

        # reduction from local buffers.
        if update_centers:
            # The lock is necessary to avoid race conditions when aggregating
            # info from different thread-local buffers.
            omp_set_lock(&lock)
            for j in range(n_clusters):
                weight_in_clusters[j] += weight_in_clusters_chunk[j]
                for k in range(n_features):
                    centers_new[j, k] += centers_new_chunk[j * n_features + k]
            omp_unset_lock(&lock)

        free(centers_new_chunk)
        free(weight_in_clusters_chunk)

    if update_centers:
        omp_destroy_lock(&lock)
        _relocate_empty_clusters_sparse(
            X_data, X_indices, X_indptr, sample_weight,
            centers_old, centers_new, weight_in_clusters, labels)

        _average_centers(centers_new, weight_in_clusters)
        _center_shift(centers_old, centers_new, center_shift)

        # update lower and upper bounds
        for i in range(n_samples):
            upper_bounds[i] += center_shift[labels[i]]

            for j in range(n_clusters):
                lower_bounds[i, j] -= center_shift[j]
                if lower_bounds[i, j] < 0:
                    lower_bounds[i, j] = 0


cdef void _update_chunk_sparse(
        const floating[::1] X_data,                    # IN
        const int[::1] X_indices,                      # IN
        const int[::1] X_indptr,                       # IN
        const floating[::1] sample_weight,             # IN
        const floating[:, ::1] centers_old,            # IN
        const floating[::1] centers_squared_norms,     # IN
        const floating[:, ::1] center_half_distances,  # IN
        const floating[::1] distance_next_center,      # IN
        int[::1] labels,                               # INOUT
        floating[::1] upper_bounds,                    # INOUT
        floating[:, ::1] lower_bounds,                 # INOUT
        floating *centers_new,                         # OUT
        floating *weight_in_clusters,                  # OUT
        bint update_centers) noexcept nogil:
    """K-means combined EM step for one sparse data chunk.

    Compute the partial contribution of a single data chunk to the labels and
    centers.
    """
    cdef:
        int n_samples = labels.shape[0]
        int n_clusters = centers_old.shape[0]
        int n_features = centers_old.shape[1]

        floating upper_bound, distance
        int i, j, k, label
        int s = X_indptr[0]

    for i in range(n_samples):
        upper_bound = upper_bounds[i]
        bounds_tight = 0
        label = labels[i]

        # Next center is not far away from the currently assigned center.
        # Sample might need to be assigned to another center.
        if not distance_next_center[label] >= upper_bound:

            for j in range(n_clusters):

                # If this holds, then center_index is a good candidate for the
                # sample to be relabelled, and we need to confirm this by
                # recomputing the upper and lower bounds.
                if (
                    j != label
                    and (upper_bound > lower_bounds[i, j])
                    and (upper_bound > center_half_distances[label, j])
                ):

                    # Recompute upper bound by calculating the actual distance
                    # between the sample and its current assigned center.
                    if not bounds_tight:
                        upper_bound = _euclidean_sparse_dense(
                            X_data[X_indptr[i] - s: X_indptr[i + 1] - s],
                            X_indices[X_indptr[i] - s: X_indptr[i + 1] - s],
                            centers_old[label], centers_squared_norms[label], False)
                        lower_bounds[i, label] = upper_bound
                        bounds_tight = 1

                    # If the condition still holds, then compute the actual
                    # distance between the sample and center. If this is less
                    # than the previous distance, reassign label.
                    if (
                        upper_bound > lower_bounds[i, j]
                        or (upper_bound > center_half_distances[label, j])
                    ):
                        distance = _euclidean_sparse_dense(
                            X_data[X_indptr[i] - s: X_indptr[i + 1] - s],
                            X_indices[X_indptr[i] - s: X_indptr[i + 1] - s],
                            centers_old[j], centers_squared_norms[j], False)
                        lower_bounds[i, j] = distance
                        if distance < upper_bound:
                            label = j
                            upper_bound = distance

            labels[i] = label
            upper_bounds[i] = upper_bound

        if update_centers:
            weight_in_clusters[label] += sample_weight[i]
            for k in range(X_indptr[i] - s, X_indptr[i + 1] - s):
                centers_new[label * n_features + X_indices[k]] += X_data[k] * sample_weight[i]