1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298
|
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import cupy as cp
import numpy as np
import nvtx
from cuda.compute import (
PermutationIterator, ZipIterator, TransformIterator, CountingIterator,
OpKind, exclusive_scan, select, unary_transform
)
@nvtx.annotate("segment_sizes")
def segment_sizes(offsets):
    """
    Return the length of every segment described by an offsets array.

    Args:
        offsets: Device array of segment offsets (length = num_segments + 1).
            Segment ``i`` spans the elements from ``offsets[i]`` up to, but
            not including, ``offsets[i + 1]``.

    Returns:
        Device array of segment sizes (length = num_segments), computed as
        the difference between consecutive offsets.
    """
    segment_starts = offsets[:-1]
    segment_ends = offsets[1:]
    return segment_ends - segment_starts
@nvtx.annotate("offsets_to_segment_ids")
def offsets_to_segment_ids(offsets, stream=None):
    """
    Expand segment offsets into one segment ID per element.

    Given offsets [0, 2, 5, 8, 10], produces [0, 0, 1, 1, 1, 2, 2, 2, 3, 3].

    Args:
        offsets: Array of segment offsets. Its last entry is read as the
            total number of elements.
        stream: Not used by this function.

    Returns:
        Array holding, for every element index, the ID of the segment that
        contains it. An empty int32 array is returned when the last offset
        is zero.
    """
    # int() pulls the last offset out of the array as a Python integer.
    total_elements = int(offsets[-1])
    if total_elements == 0:
        return cp.array([], dtype=np.int32)

    # TODO: when available this can be a fused CountingIterator + lower_bound
    # All element positions: [0, 1, 2, ..., total_elements - 1].
    positions = cp.arange(total_elements, dtype=np.int32)

    # Element i belongs to the segment j with offsets[j] <= i < offsets[j + 1].
    # searchsorted(..., side="right") returns how many offsets are <= i,
    # which is j + 1, so subtracting 1 yields the segment ID. Empty segments
    # are skipped automatically because their repeated offset is counted too.
    #
    # Example: offsets = [0, 0, 2, 3], positions = [0, 1, 2]
    #   searchsorted(offsets, positions, side="right") = [2, 2, 3]
    #   minus 1                                         = [1, 1, 2]
    # (segment 0 is empty, so element 0 lands in segment 1).
    insertion_points = cp.searchsorted(offsets, positions, side="right")
    return insertion_points - 1
@nvtx.annotate("select_segments")
def select_segments(
    data_in,
    offsets_in,
    mask_in,
    data_out,
    offsets_out,
    d_num_selected_out,
    num_elements,
    num_segments,
    stream=None,
):
    """
    Select segments of a segmented array using a per-segment mask.

    A segmented array is conceptually composed of a data array and segment
    offsets. For example, with
    data=[30, 20, 20, 50, 90, 10, 30, 80, 20, 60] and
    offsets=[0, 2, 5, 8, 10], the segmented array represents:
    [[30, 20], [20, 50, 90], [10, 30, 80], [20, 60]].

    Given the mask [0, 1, 0, 1], the function will produce the segmented
    array [[20, 50, 90], [20, 60]].

    Example:
        >>> data_in = cp.array([30, 20, 20, 50, 90, 10, 30, 80, 20, 60], dtype=np.int32)
        >>> offsets_in = cp.array([0, 2, 5, 8, 10], dtype=np.int32)
        >>> mask_in = cp.array([0, 1, 0, 1], dtype=np.int8)
        >>> data_out = cp.empty_like(data_in)
        >>> offsets_out = cp.empty_like(offsets_in)
        >>> d_num_selected = cp.zeros(2, dtype=np.int32)
        >>> select_segments(data_in, offsets_in, mask_in, data_out, offsets_out,
        ...                 d_num_selected, len(data_in), len(offsets_in) - 1)
        >>> # Result: data_out contains [20, 50, 90, 20, 60, ...]
        >>> #         offsets_out contains [0, 3, 5, ...]
        >>> #         d_num_selected[0] = 5 (number of data elements)
        >>> #         d_num_selected[1] = 2 (number of segments kept)

    Args:
        data_in: Device array or iterator containing all segment elements
            concatenated.
        offsets_in: Device array of segment offsets
            (length = num_segments + 1). Segment i contains elements from
            offsets[i] to offsets[i+1]. It is indexed and passed to
            ``cp.searchsorted`` by this function.
        mask_in: Per-segment mask (int8, length = num_segments). Non-zero
            values indicate segments to keep. Besides being read through a
            PermutationIterator, it is converted with ``cp.asarray`` and
            used for boolean indexing when the output offsets are built.
        data_out: Device array or iterator to store selected data elements.
            Should be pre-allocated with size >= num_elements.
        offsets_out: Device array to store new segment offsets.
            Should be pre-allocated with size >= num_segments + 1.
        d_num_selected_out: Device array to store counts (length >= 2):
            - d_num_selected_out[0]: number of selected data elements
            - d_num_selected_out[1]: number of segments kept
        num_elements: Total number of elements in data_in.
        num_segments: Total number of segments (= len(offsets_in) - 1).
            Currently not read by the implementation.
        stream: CUDA stream forwarded to the ``select`` and
            ``exclusive_scan`` calls (optional).

    Returns:
        None. Results are written into data_out, offsets_out and
        d_num_selected_out.
    """
    # Step 1: Create segment_indices array indicating which segment each
    # element belongs to.
    segment_indices = offsets_to_segment_ids(offsets_in, stream)

    # Step 2: Expand mask from per-segment to per-element using
    # PermutationIterator. Each element gets the mask value of its segment.
    expanded_mask_it = PermutationIterator(mask_in, segment_indices)

    # Step 3: Filter the data array and capture the original element indices
    # in a single select call by zipping data, expanded mask and a counter.
    data_mask_idx_in = ZipIterator(
        data_in, expanded_mask_it, CountingIterator(np.int32(0))
    )
    d_selected_indices = cp.empty(num_elements, dtype=np.int32)
    data_idx_out = ZipIterator(data_out, d_selected_indices)
    d_num_data_selected = cp.zeros(1, dtype=np.int32)

    # Predicate: keep a (data, mask, index) triple when its mask is non-zero.
    def mask_predicate(triple):
        return triple[1] != 0

    select(
        data_mask_idx_in,
        data_idx_out,
        d_num_data_selected,
        mask_predicate,
        num_elements,
        stream,
    )

    # Read the number of selected elements back as a Python integer and trim
    # the index buffer to the part that was actually written.
    num_selected = int(d_num_data_selected[0])
    d_selected_indices = d_selected_indices[:num_selected]

    # Step 4: Compute new segment offsets using the captured indices.
    # TODO: this part should use run_length_encode when available
    # side='left' counts the selected indices strictly below each offset
    # boundary, so consecutive differences give the selected count per
    # input segment.
    positions = cp.searchsorted(d_selected_indices, offsets_in, side='left')
    segment_counts = (positions[1:] - positions[:-1]).astype(np.int32)

    # Keep only the sizes of the segments whose mask is non-zero.
    mask_array = cp.asarray(mask_in)
    kept_segment_sizes = segment_counts[mask_array != 0]
    num_kept_segments = len(kept_segment_sizes)

    # Exclusive scan converts the kept sizes into start offsets, written to
    # offsets_out[0:num_kept_segments]. Skipped when nothing was kept.
    if num_kept_segments > 0:
        h_init_scan = np.array([0], dtype=np.int32)
        exclusive_scan(
            kept_segment_sizes,
            offsets_out,
            OpKind.PLUS,
            h_init_scan,
            num_kept_segments,
            stream,
        )

    # The closing offset is the total number of selected elements.
    offsets_out[num_kept_segments] = num_selected

    # Store the counts in d_num_selected_out.
    d_num_selected_out[0] = num_selected  # number of data elements
    d_num_selected_out[1] = num_kept_segments  # number of segments kept
@nvtx.annotate("segmented_select")
def segmented_select(
    d_in_data,
    d_in_segments,
    d_out_data,
    d_out_segments,
    cond,
    num_items: int,
    stream=None,
) -> int:
    """
    Filter the items of a segmented array while preserving its segmentation.

    The condition is evaluated on every input item; items that pass are
    compacted into ``d_out_data`` and ``d_out_segments`` receives the offsets
    of the (possibly shorter, possibly empty) output segments.

    Args:
        d_in_data: Device array containing the input data items.
        d_in_segments: Device array of segment offsets. For N segments this
            array has N+1 elements; segment i covers the range
            [d_in_segments[i], d_in_segments[i + 1]).
        d_out_data: Device array to store selected data (pre-allocated,
            should be at least as large as d_in_data).
        d_out_segments: Device array to store output segment offsets
            (pre-allocated, same size as d_in_segments).
        cond: Callable that takes a data item and returns a boolean-like
            value indicating whether to keep the item. It is wrapped with
            ``numba.cuda.jit`` before use.
        num_items: Total number of items in d_in_data.
        stream: CUDA stream forwarded to the ``select`` and
            ``exclusive_scan`` calls (optional).

    Returns:
        int: Total number of items after selection (the value written to
        d_out_segments[-1]).

    Example:
        >>> # Input: [[45], [25, 35], [15]] with condition x > 30
        >>> # Output: [[45], [35], []] -> offsets [0, 1, 2, 2]
        >>> def greater_than_30(x):
        ...     return x > 30
        >>> d_in_data = cp.array([45, 25, 35, 15], dtype=cp.int32)
        >>> d_in_segments = cp.array([0, 1, 3, 4], dtype=cp.int32)
        >>> d_out_data = cp.empty_like(d_in_data)
        >>> d_out_segments = cp.empty_like(d_in_segments)
        >>> total = segmented_select(
        ...     d_in_data, d_in_segments, d_out_data, d_out_segments,
        ...     greater_than_30, len(d_in_data)
        ... )
        >>> print(total)  # 2
        >>> print(d_out_segments.get())  # [0, 1, 2, 2]
    """
    import numba.cuda

    cond = numba.cuda.jit(cond)

    # Phase 1: compact the data and, alongside it, each survivor's original
    # position. The predicate only looks at the data half of each pair.
    def keep_pair(pair):
        return cond(pair[0])

    zipped_in = ZipIterator(d_in_data, CountingIterator(np.int32(0)))
    d_kept_positions = cp.empty(num_items, dtype=np.int32)
    zipped_out = ZipIterator(d_out_data, d_kept_positions)
    d_kept_count = cp.zeros(1, dtype=cp.uint64)
    select(zipped_in, zipped_out, d_kept_count, keep_pair, num_items, stream)

    # Bring the survivor count back as a Python int and drop the unused tail
    # of the position buffer.
    total_selected = int(d_kept_count[0])
    d_kept_positions = d_kept_positions[:total_selected]

    # Phase 2: for every segment boundary, count the surviving positions that
    # lie strictly below it (side='left'); adjacent differences are then the
    # per-segment survivor counts.
    boundaries = cp.searchsorted(d_kept_positions, d_in_segments, side='left')
    d_per_segment = (boundaries[1:] - boundaries[:-1]).astype(cp.uint64)

    # Phase 3: an exclusive prefix sum of the counts gives each output
    # segment's start offset; the closing offset is the overall total.
    num_segments = len(d_in_segments) - 1
    exclusive_scan(
        d_per_segment,
        d_out_segments[:-1],
        OpKind.PLUS,
        np.array(0, dtype=np.uint64),
        num_segments,
        stream,
    )
    d_out_segments[-1] = total_selected

    return total_selected
@nvtx.annotate("transform_segments")
def transform_segments(data_in, data_out, segment_size, op, num_segments,
                       stream=None):
    """
    Transform each fixed-size segment of a segmented array into one value.

    Given a segmented array where each segment contains the same number of
    items, transform each segment independently using the given n-ary
    operation.

    For example, given the segmented array [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    and the operation x + y + z, the result is the segmented array
    [[6], [15], [24]].

    Args:
        data_in: Device array or iterator holding the segments back to back,
            i.e. item i of segment j is read from flat index
            ``j * segment_size + i``.
        data_out: Output array or iterator receiving one result per segment.
        segment_size: Number of items in every segment. Evaluated on the host
            via ``range(segment_size)`` to build one column iterator per
            position within a segment.
        op: Operation handed to ``unary_transform``. It is applied to the
            zipped columns, so it presumably receives the ``segment_size``
            items of one segment as a single tuple-like value -- TODO confirm
            against the ``ZipIterator`` documentation.
        num_segments: Number of segments (number of output values).
        stream: CUDA stream forwarded to ``unary_transform`` (optional).

    Returns:
        Whatever ``unary_transform`` returns.
    """
    def get_column(it, i):
        # Iterator over the i-th item of every segment.
        def col_major_index(j: np.int32) -> np.int32:
            # Map segment index j to the flat index of item i in segment j.
            return j * segment_size + i

        return PermutationIterator(
            it,
            TransformIterator(CountingIterator(np.int32(0)), col_major_index),
        )

    # Zipping the per-position columns makes each zipped value represent one
    # whole segment, so a unary transform over it acts as an n-ary reduction.
    columns = ZipIterator(
        *[get_column(data_in, i) for i in range(segment_size)]
    )
    return unary_transform(columns, data_out, op, num_segments, stream)
|