File: _segment_algorithms.py

package info (click to toggle)
python-awkward 2.9.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 35,360 kB
  • sloc: python: 187,941; cpp: 33,672; sh: 432; ansic: 256; makefile: 21; javascript: 8
file content (298 lines) | stat: -rw-r--r-- 11,863 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import cupy as cp
import numpy as np
import nvtx

from cuda.compute import (
    PermutationIterator, ZipIterator, TransformIterator, CountingIterator,
    OpKind, exclusive_scan, select, unary_transform
)


@nvtx.annotate("segment_sizes")
def segment_sizes(offsets):
    """
    Return the length of every segment described by ``offsets``.

    Args:
        offsets: Device array of segment offsets (length = num_segments + 1).
                 Segment i spans offsets[i] (inclusive) to offsets[i+1]
                 (exclusive).

    Returns:
        Device array of segment sizes (length = num_segments), i.e. the
        difference between each pair of consecutive offsets.
    """
    segment_ends = offsets[1:]
    segment_starts = offsets[:-1]
    return segment_ends - segment_starts


@nvtx.annotate("offsets_to_segment_ids")
def offsets_to_segment_ids(offsets, stream=None):
    """
    Convert segment offsets to per-element segment IDs (indicators).

    Given offsets [0, 2, 5, 8, 10], produces [0, 0, 1, 1, 1, 2, 2, 2, 3, 3].
    Empty segments own no elements: offsets [0, 0, 2, 3] produce [1, 1, 2].

    Args:
        offsets: Device array of segment offsets (length = num_segments + 1),
                 expected to start at 0 and be non-decreasing. ``offsets[-1]``
                 is read on the host to size the result.
        stream: Not referenced by this implementation.

    Returns:
        Device array of int32 segment IDs with one entry per element
        (length = offsets[-1]). The dtype is int32 in both the empty and the
        non-empty case.
    """
    num_elements = int(offsets[-1])

    if num_elements == 0:
        return cp.array([], dtype=np.int32)

    # TODO: when available this can be a fused CountingIterator + lower_bound

    # Indices of all elements: [0, 1, ..., num_elements - 1].
    element_indices = cp.arange(num_elements, dtype=np.int32)

    # Element i belongs to segment j where offsets[j] <= i < offsets[j + 1].
    # searchsorted(..., side="right") returns how many offsets are <= i, so
    # subtracting 1 yields j. Searching the full offsets array with
    # side="right" makes a run of equal offsets (empty segments) resolve to
    # the last segment starting at that position, i.e. the non-empty one.
    #
    # Example: offsets = [0, 0, 2, 3], element_indices = [0, 1, 2]
    #   searchsorted(offsets, [0, 1, 2], side="right") = [2, 2, 3]
    #   minus 1                                        = [1, 1, 2]
    segment_ids = cp.searchsorted(offsets, element_indices, side="right") - 1

    # searchsorted returns 64-bit indices; cast so the result dtype matches
    # the int32 used for element_indices and for the empty case above.
    return segment_ids.astype(np.int32, copy=False)


@nvtx.annotate("select_segments")
def select_segments(
    data_in,
    offsets_in,
    mask_in,
    data_out,
    offsets_out,
    d_num_selected_out,
    num_elements,
    num_segments,
    stream=None,
):
    """
    Select segments of a segmented array using a per-segment mask.

    A segmented array is conceptually composed of a data array and segment offsets.
    For example, with data=[30, 20, 20, 50, 90, 10, 30, 80, 20, 60] and
    offsets=[0, 2, 5, 8, 10], the segmented array represents:
    [[30, 20], [20, 50, 90], [10, 30, 80], [20, 60]].
    Given the mask [0, 1, 0, 1], the function will return the segmented array
    [[20, 50, 90], [20, 60]].

    Example:
        >>> data_in = cp.array([30, 20, 20, 50, 90, 10, 30, 80, 20, 60], dtype=np.int32)
        >>> offsets_in = cp.array([0, 2, 5, 8, 10], dtype=np.int32)
        >>> mask_in = cp.array([0, 1, 0, 1], dtype=np.int8)
        >>> data_out = cp.empty_like(data_in)
        >>> offsets_out = cp.empty_like(offsets_in)
        >>> d_num_selected = cp.zeros(2, dtype=np.int32)
        >>> select_segments(data_in, offsets_in, mask_in, data_out, offsets_out,
        ...                 d_num_selected, len(data_in), len(offsets_in) - 1)
        >>> # Result: data_out contains [20, 50, 90, 20, 60, ...]
        >>> #         offsets_out contains [0, 3, 5, ...]
        >>> #         d_num_selected[0] = 5 (number of data elements)
        >>> #         d_num_selected[1] = 2 (number of segments kept)

    Args:
        data_in: Device array or iterator containing all segment elements concatenated.
        offsets_in: Device array of segment offsets (length = num_segments + 1).
                    Each segment i contains elements from offsets[i] to offsets[i+1].
                    Must be an array: it is passed to ``cp.searchsorted`` and its
                    last value is read on the host.
        mask_in: Device array (int8) indicating which segments to keep
                 (length = num_segments). Non-zero values indicate segments to keep.
                 Must be convertible with ``cp.asarray``.
        data_out: Device array or iterator to store selected data elements.
                  Should be pre-allocated with size >= num_elements.
        offsets_out: Device array to store new segment offsets (item assignment
                     is used). Should be pre-allocated with size >= num_segments + 1.
                     Only the first (number of kept segments + 1) entries are written.
        d_num_selected_out: Device array to store counts (length >= 2):
                           - d_num_selected_out[0]: number of selected data elements
                           - d_num_selected_out[1]: number of segments kept
        num_elements: Total number of elements in data_in; expected to equal
                      offsets_in[-1].
        num_segments: Total number of segments (= len(offsets_in) - 1). Not read
                      by the current implementation.
        stream: CUDA stream passed to the cuda.compute algorithms (optional).
    """
    # Step 1: Create segment_indices array indicating which segment each element belongs to
    segment_indices = offsets_to_segment_ids(offsets_in, stream)

    # Step 2: Expand mask from per-segment to per-element using PermutationIterator
    # Each element gets the mask value of its corresponding segment
    expanded_mask_it = PermutationIterator(mask_in, segment_indices)

    # Step 3: Filter the data array and capture indices in a single select call.
    # Items are (data, element index, mask). The index is placed directly after
    # the data so the leading components line up with the
    # (data_out, d_selected_indices) output zip, matching the (data, index)
    # layout used by segmented_select.
    # NOTE(review): this assumes the output zip is filled positionally from the
    # leading components of each selected item; confirm against the
    # cuda.compute ZipIterator semantics.
    data_idx_mask_in = ZipIterator(
        data_in, CountingIterator(np.int32(0)), expanded_mask_it)
    d_selected_indices = cp.empty(num_elements, dtype=np.int32)
    data_idx_out = ZipIterator(data_out, d_selected_indices)
    d_num_data_selected = cp.zeros(1, dtype=np.int32)

    # Keep an element when its (expanded) segment mask value is non-zero.
    def mask_predicate(triple):
        # triple = (data, element index, mask value)
        return triple[2] != 0

    select(
        data_idx_mask_in,
        data_idx_out,
        d_num_data_selected,
        mask_predicate,
        num_elements,
        stream,
    )

    # Get the actual number of selected elements (host sync)
    num_selected = int(d_num_data_selected[0])
    d_selected_indices = d_selected_indices[:num_selected]

    # Step 4: Compute new segment offsets using the captured indices
    # TODO: this part should use run_length_encode when available

    # The selected indices are sorted, so the number of selected indices that
    # fall in [offsets[i], offsets[i+1]) is the difference of the insertion
    # points of the two boundaries (side='left' counts indices < boundary).
    positions = cp.searchsorted(d_selected_indices, offsets_in, side='left')
    segment_counts = (positions[1:] - positions[:-1]).astype(np.int32)

    # Keep the sizes of the segments whose mask is non-zero
    mask_array = cp.asarray(mask_in)
    kept_segment_sizes = segment_counts[mask_array != 0]
    num_kept_segments = len(kept_segment_sizes)

    # Exclusive scan converts the kept sizes into their start offsets,
    # written to offsets_out[0:num_kept_segments].
    if num_kept_segments > 0:
        h_init_scan = np.array([0], dtype=np.int32)
        exclusive_scan(
            kept_segment_sizes,
            offsets_out,
            OpKind.PLUS,
            h_init_scan,
            num_kept_segments,
            stream,
        )

    # Set the final offset to the total number of selected elements
    offsets_out[num_kept_segments] = num_selected

    # Store the counts in d_num_selected_out
    d_num_selected_out[0] = num_selected  # number of data elements
    d_num_selected_out[1] = num_kept_segments  # number of segments kept


@nvtx.annotate("segmented_select")
def segmented_select(
    d_in_data,
    d_in_segments,
    d_out_data,
    d_out_segments,
    cond,
    num_items: int,
    stream=None,
) -> int:
    """
    Select data within segments independently based on a condition.

    Given segmented input data and a selection condition, this function
    applies the selection to each segment independently and produces compacted
    output with updated segment offsets.

    Args:
        d_in_data: Device array containing the input data items.
        d_in_segments: Device array of segment offsets. For N segments,
            this array has N+1 elements where segments[i] and segments[i+1]
            define the range [start, end) for segment i.
        d_out_data: Device array to store selected data (pre-allocated,
            should be at least as large as d_in_data).
        d_out_segments: Device array to store output segment offsets
            (pre-allocated, same size as d_in_segments). The per-segment
            counts and the scan are computed in this array's dtype.
        cond: Callable that takes a data item and returns a boolean-like
            value (typically uint8) indicating whether to keep the item.
        num_items: Total number of items in d_in_data.
        stream: CUDA stream for the operation (optional).

    Returns:
        int: Total number of items after selection (equal to d_out_segments[-1]).

    Example:
        >>> # Input: [[45], [25, 35], [15]] with condition x > 30
        >>> # Output: [[45], [35], []] -> offsets [0, 1, 2, 2]
        >>> def greater_than_30(x):
        ...     return x > 30
        >>> d_in_data = cp.array([45, 25, 35, 15], dtype=cp.int32)
        >>> d_in_segments = cp.array([0, 1, 3, 4], dtype=cp.int32)
        >>> d_out_data = cp.empty_like(d_in_data)
        >>> d_out_segments = cp.empty_like(d_in_segments)
        >>> total = segmented_select(
        ...     d_in_data, d_in_segments, d_out_data, d_out_segments,
        ...     greater_than_30, len(d_in_data)
        ... )
        >>> print(total)  # 2
        >>> print(d_out_segments.get())  # [0, 1, 2, 2]
    """
    import numba.cuda

    num_segments = len(d_in_segments) - 1

    # Compile ``cond`` with numba.cuda; presumably required so that
    # ``select_predicate`` can call it from device code.
    cond = numba.cuda.jit(cond)

    # Step 1: Apply select to get the data and indices where condition is true

    def select_predicate(pair):
        # pair = (data item, element index)
        return cond(pair[0])

    data_idx_in = ZipIterator(d_in_data, CountingIterator(np.int32(0)))
    d_indices_out = cp.empty(num_items, dtype=np.int32)
    data_idx_out = ZipIterator(d_out_data, d_indices_out)
    d_num_selected = cp.zeros(1, dtype=cp.uint64)
    select(data_idx_in, data_idx_out,
           d_num_selected, select_predicate, num_items, stream)

    # Host sync: the number of selected items sizes the index view below.
    total_selected = int(d_num_selected[0])
    d_selected_indices = d_indices_out[:total_selected]

    # Step 2: Use searchsorted to count selected items per segment.
    # The selected indices are sorted; side='left' counts indices strictly
    # less than each offset boundary, so consecutive differences are counts.
    positions = cp.searchsorted(
        d_selected_indices, d_in_segments, side='left')

    # Compute counts and the scan in the output offsets' dtype so the scan's
    # input, initial value and output all agree (e.g. int32 offsets).
    offsets_dtype = d_out_segments.dtype
    d_counts = (positions[1:] - positions[:-1]).astype(offsets_dtype)

    # Step 3: Use exclusive scan to compute output segment start offsets
    exclusive_scan(
        d_counts,
        d_out_segments[:-1],
        OpKind.PLUS,
        np.array(0, dtype=offsets_dtype),
        num_segments,
        stream,
    )

    # Set the final offset to the total count
    d_out_segments[-1] = total_selected
    return total_selected


@nvtx.annotate("transform_segments")
def transform_segments(data_in, data_out, segment_size, op, num_segments, stream=None):
    """
    Transform each fixed-size segment of a segmented array with an n-ary op.

    Every segment holds exactly ``segment_size`` items stored contiguously:
    item i of segment j lives at flat position ``j * segment_size + i``.
    The items of each segment are gathered into one zipped element and
    ``op`` is applied to it, producing one output value per segment.

    For example, given the segmented array [[1, 2, 3], [4, 5, 6], [7, 8, 9]] and the
    operation x + y + z, the function will return the segmented array [[6], [15], [24]].

    Args:
        data_in: Device array or iterator holding the flattened segments.
        data_out: Device array or iterator receiving one value per segment.
        segment_size: Number of items in every segment (also the number of
            zipped columns handed to ``op``).
        op: Operation applied to each segment's zipped items.
        num_segments: Number of segments to transform.
        stream: CUDA stream for the operation (optional).

    Returns:
        The value returned by ``unary_transform``.
    """

    def get_column(it, i):
        # Iterator over item i of every segment: its j-th value is
        # it[j * segment_size + i].
        def segment_item_index(j: np.int32) -> np.int32:
            # Flat (row-major) position of item i within segment j.
            return j * segment_size + i

        return PermutationIterator(
            it,
            TransformIterator(CountingIterator(np.int32(0)), segment_item_index),
        )

    # Element j of the zip is the tuple of all items of segment j.
    columns = ZipIterator(
        *[get_column(data_in, i) for i in range(segment_size)]
    )
    return unary_transform(columns, data_out, op, num_segments, stream)