1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620
|
// Inter-block reduction.
//
// The gridReduce function performs point-wise reductions of scalars across
// thread blocks. Thread blocks are disjointly partitioned into groups,
// "reduction segments", that are collectively defined by boolean template
// parameters, X_BLOCK, Y_BLOCK and Z_BLOCK. Each of X/Y/Z_BLOCK determines
// whether thread blocks along the dimension should be grouped into the same
// reduction segment. Cross-block reductions are done independently within each
// segment and generate a distinct result per segment. For instance, if all
// of X/Y/Z_BLOCK are true, reductions will be done across all thread blocks
// since there will be just a single segment consisting of all thread blocks. If
// none of them are true, each thread block will become a segment by itself, so
// no reduction will be performed.
//
// The input scalars to reduce within each segment are a certain subset of
// thread-private scalars provided as part of the gridReduce function
// parameters. Boolean template parameters, X_THREAD, Y_THREAD and Z_THREAD,
// determine which subset of the scalars should be used for inter-block
// reductions. Specifically, all the input scalars of threads along each
// dimension will be used when X/Y/Z_THREAD are true. Otherwise, only the value
// held at offset 0 of each dimension will be used. Thus, for example, if all of
// X/Y/Z_THREAD are true, the scalars of all threads in each block will
// participate in inter-block reductions. If all of them are false, only one
// scalar of the thread at threadIdx.x == threadIdx.y == threadIdx.z == 0 will
// be used. In the code below, we call the subset of threads a "reduction
// block". "Participating" thread dimensions here are similar to the
// "non-participating" block dimensions. They come from a block dimension that
// has not been reduced before hitting this grid reduction.
//
// Inter-block reductions perform point-wise reductions of scalars of reduction
// blocks within each reduction segment. More specifically, let rb be a
// reduction block and rs be a reduction segment. Let IN(thread_idx, block_idx)
// denote the input scalar of thread at thread_idx and block_idx. The result of
// each reduction segment, OUT(thread_idx, block_idx_out), is defined only for
// each thread_idx in thread block block_idx_out in the segment as follows:
//
// OUT(thread_idx, block_idx_out) =
// Reduction of IN(thread_idx, block_idx) for
// all block_idx in a reduction segment
//
// OUT is not defined for threads that are outside block_idx_out or outside
// the reduction block.
//
// See also the function comment of gridReduce.
namespace reduction {
// Cleanup stage of a grid reduction: folds the partial results of every block
// in one reduction segment into the final per-thread outputs.
//
// Callers in this file invoke it only from the block selected by
// index_utils::maskedIsLast, i.e. from one block per reduction segment. The
// segment's partial results are read from the intermediate buffer `in`, which
// is laid out as grid_reduction_segment_size rows (one per contributing block)
// of block_reduction_segment_size entries each.
//
// Template parameters X/Y/Z_THREAD describe how a reduction block is formed.
// A thread dimension marked true indexes distinct output values. Dimensions
// marked false own no output of their own and are used to parallelize the
// cleanup. For example, with X/Y/Z_THREAD = {true, false, false},
// blockDim.y*blockDim.z threads cooperate on each output value. Each of them
// first accumulates a strided subset of the rows. The per-thread partials are
// then combined with blockReduce across the false dimensions.
//
// The result is combined into `out` with reduction_op rather than assigned.
// Only threads at offset 0 of every false dimension write it, and only when
// write_pred holds. Everything after the load from global memory mirrors what
// blockReduce does.
template <
    bool X_THREAD,
    bool Y_THREAD,
    bool Z_THREAD,
    typename T,
    typename Func>
__device__ void gridReduceLastBlock(
    T& out,
    const volatile T* in,
    const nvfuser_index_t
        grid_reduction_segment_size, // Number of reductions across
                                     // grid reduce dimensions
    const nvfuser_index_t
        block_reduction_segment_size, // Number of reductions across the block
    Func reduction_op,
    T* shared_buf,
    bool write_pred,
    T init_val) {
  // Column of `in` owned by this thread: its position along the dimensions
  // that index distinct outputs (X/Y/Z_THREAD == true).
  const auto output_column =
      index_utils::maskedOffset<X_THREAD, Y_THREAD, Z_THREAD>(
          threadIdx, blockDim);
  // First row handled by this thread, and the row stride. Both come from the
  // helper dimensions (X/Y/Z_THREAD == false) that parallelize the cleanup.
  const auto first_row =
      index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>(
          threadIdx, blockDim);
  const auto row_stride =
      index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim);

  // Each helper thread accumulates every row_stride-th row of its column.
  T strided_partial = init_val;
  for (nvfuser_index_t row = first_row; row < grid_reduction_segment_size;
       row += row_stride) {
    reduction_op(
        strided_partial, in[row * block_reduction_segment_size + output_column]);
  }

  // Combine the helper threads' partials. blockReduce reduces over the
  // helper dimensions, i.e. those whose X/Y/Z_THREAD is false.
  T segment_result = init_val;
  blockReduce<!X_THREAD, !Y_THREAD, !Z_THREAD>(
      segment_result,
      strided_partial,
      reduction_op,
      threadIdx,
      blockDim,
      shared_buf,
      true,
      init_val);

  // Exactly one helper thread (offset 0 in every helper dimension) per output
  // column publishes the result.
  if (write_pred && (X_THREAD || threadIdx.x == 0) &&
      (Y_THREAD || threadIdx.y == 0) && (Z_THREAD || threadIdx.z == 0)) {
    reduction_op(out, segment_result);
  }
}
// Reduces per-thread values across threads and thread blocks.
//
// Function parameters:
// - out: Per-thread output location. The reduced result is folded into its
//   current value with reduction_op rather than assigned.
// - inp_val: Per-thread input value
// - reduction_op: Scalar reduction function, called as reduction_op(acc, val);
//   it must update acc in place
// - work_buf: Temporary buffer for cross-block reductions
// - sync_flags: Integer flags handed to grid_sync::sync. In persistent mode
//   there is one flag per reduction segment. Otherwise there is one flag per
//   segment and per entrance.
// - shared_buf: Shared memory buffer for intra-block reduction
// - read_pred: Predicate guarding the read of inp_val (also forwarded to
//   blockReduce)
// - write_pred: Predicate guarding the final update of out
// - init_val: Initial value for all accumulators (presumably the identity of
//   reduction_op)
// - entrance_ind / n_entrances: see below. n_entrances is currently not used
//   by this function body.
//
// Only the thread block selected by maskedIsLast (see below) ends up with
// valid results.
//
// Template parameters:
// - X/Y/Z_BLOCK: When true, reduces across thread blocks along the X/Y/Z grid
//   dimension
// - X/Y/Z_THREAD: When true, reduces across threads along the X/Y/Z block
//   dimension (with blockReduce, before the cross-block stage)
// - PERSISTENT_REDUCTION: Indicates that the grid reduction will be called in
//   a loop, or that its result will be broadcast and used across the grid.
//   This requires cross-grid communication, and the grid synchronizations
//   here must actually synchronize the entire grid. When false, the grid is
//   not synchronized: the last block just waits for everyone else to finish,
//   and the other blocks can exit early.
// - T: Scalar data type of input/output data
// - Func: Type of scalar reduction function
//
// Template parameters X/Y/Z_BLOCK define a group of thread blocks that are
// reduced together. We call it a reduction segment. Some examples are:
//
// Case 1: X/Y/Z_BLOCK == true/true/true -> There is only one segment, which
// includes all thread blocks. It is effectively the same as the grid.
//
// Case 2: X/Y/Z_BLOCK == false/false/false -> Each thread block comprises an
// individual segment by itself.
//
// Case 3: X/Y/Z_BLOCK == true/false/false -> Each segment contains the thread
// blocks that share the same blockIdx.y and blockIdx.z. There will be
// gridDim.y*gridDim.z such segments.
//
// X/Y/Z_THREAD work similarly to X/Y/Z_BLOCK and define a group of threads
// that are reduced together.
//
// After the function completes, only one thread block per reduction segment
// holds valid reduction results. That block is the one for which
// index_utils::maskedIsLast<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim) is
// true.
//
// entrance_ind and n_entrances are only meaningful when
// PERSISTENT_REDUCTION = false. If a grid reduction is only called once per
// thread, entrance_ind == 0 and n_entrances == 1. However, grid reduction can
// be called in a loop in a thread. In that case entrance_ind is the number of
// times the function has already been called, and n_entrances is the total
// number of times it will be called.
template <
    bool X_BLOCK,
    bool Y_BLOCK,
    bool Z_BLOCK,
    bool X_THREAD,
    bool Y_THREAD,
    bool Z_THREAD,
    bool PERSISTENT_REDUCTION,
    typename T,
    typename Func>
__device__ void gridReduce(
    T& out,
    const T& inp_val,
    Func reduction_op,
    volatile T* work_buf,
    int64_t* sync_flags,
    T* shared_buf,
    bool read_pred,
    bool write_pred,
    T init_val,
    const nvfuser_index_t entrance_ind,
    const nvfuser_index_t n_entrances) {
  T block_reduction_val = init_val;

  // Do block reduction when required
  if (X_THREAD || Y_THREAD || Z_THREAD) {
    blockReduce<X_THREAD, Y_THREAD, Z_THREAD>(
        block_reduction_val,
        inp_val,
        reduction_op,
        threadIdx,
        blockDim,
        shared_buf,
        read_pred,
        true,
        init_val);
  } else if (read_pred) {
    block_reduction_val = inp_val;
  }

  // Number of values (thread blocks) to reduce in the reduction segment
  const auto grid_reduction_segment_size =
      index_utils::maskedSize<X_BLOCK, Y_BLOCK, Z_BLOCK>(gridDim);

  // Index of the reduction segment this block belongs to
  const auto idx_in_grid_segment =
      index_utils::maskedOffset<!X_BLOCK, !Y_BLOCK, !Z_BLOCK>(
          blockIdx, gridDim);

  // Number of distinct values each block contributes: one per position along
  // the thread dimensions that were not block-reduced
  const auto block_reduction_segment_size =
      index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim);

  // Number of reduction segments in the grid (collapsed to 1 when persistent)
  const nvfuser_index_t grid_segment_size = PERSISTENT_REDUCTION
      ? 1
      : index_utils::maskedSize<!X_BLOCK, !Y_BLOCK, !Z_BLOCK>(gridDim);

  // Advance to the offset for this entrance and segment:
  // index of reduction * size of the reduction * size of threads
  work_buf += (entrance_ind * grid_segment_size + idx_in_grid_segment) *
      grid_reduction_segment_size * block_reduction_segment_size;

  // Threads holding a block-reduced value publish it in this block's row
  if ((!X_THREAD || threadIdx.x == 0) && (!Y_THREAD || threadIdx.y == 0) &&
      (!Z_THREAD || threadIdx.z == 0)) {
    auto block_offset =
        index_utils::maskedOffset<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim);
    auto thread_offset =
        index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>(
            threadIdx, blockDim);
    auto work_buf_offset =
        block_offset * block_reduction_segment_size + thread_offset;
    work_buf[work_buf_offset] = block_reduction_val;
  }

  if (PERSISTENT_REDUCTION) {
    grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION>(
        sync_flags[idx_in_grid_segment], grid_reduction_segment_size);
  } else {
    // Use a different sync flag for each call
    grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION>(
        sync_flags[entrance_ind * grid_segment_size + idx_in_grid_segment],
        grid_reduction_segment_size);
  }

  bool last_block =
      index_utils::maskedIsLast<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim);

  if (last_block) {
    // Cleanup with block reduction. work_buf is passed unchanged: the callee
    // takes a const volatile T*, so there is no need to cast away volatile.
    gridReduceLastBlock<!X_THREAD, !Y_THREAD, !Z_THREAD>(
        out,
        work_buf,
        grid_reduction_segment_size,
        block_reduction_segment_size,
        reduction_op,
        shared_buf,
        write_pred,
        init_val);
  }

  if (PERSISTENT_REDUCTION) {
    // Make sure we're done with global memory before we allow the kernel to
    // continue
    grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION>(
        sync_flags[idx_in_grid_segment], grid_reduction_segment_size);
  }
}
// Profiling overload of gridReduce. It takes the same arguments plus two
// counters and forwards everything to the non-profiling gridReduce. A single
// thread samples readCycleCounter() around that call: thread (0, 0, 0) of the
// last block of the whole grid, per maskedIsZero / maskedIsLast with every
// dimension masked in. That thread adds the elapsed cycles to `cycles` and
// increments `count`. Every other thread leaves both counters untouched.
#ifdef PYTORCH_NVFUSER_PROFILE_KERNEL
template <
    bool X_BLOCK,
    bool Y_BLOCK,
    bool Z_BLOCK,
    bool X_THREAD,
    bool Y_THREAD,
    bool Z_THREAD,
    bool PERSISTENT_REDUCTION,
    typename T,
    typename Func>
__device__ void gridReduce(
    T& out,
    const T& inp_val,
    Func reduction_op,
    volatile T* work_buf,
    int64_t* sync_flags,
    T* shared_buf,
    bool read_pred,
    bool write_pred,
    T init_val,
    const nvfuser_index_t entrance_ind,
    const nvfuser_index_t n_entrances,
    int64_t& cycles,
    int64_t& count) {
  // blockIdx/gridDim/threadIdx are fixed for the thread's lifetime, so the
  // timing predicate can be evaluated once and reused after the call.
  const bool is_timing_thread =
      index_utils::maskedIsLast<true, true, true>(blockIdx, gridDim) &&
      index_utils::maskedIsZero<true, true, true>(threadIdx);

  const int64_t start_counter = is_timing_thread ? readCycleCounter() : 0;

  gridReduce<
      X_BLOCK,
      Y_BLOCK,
      Z_BLOCK,
      X_THREAD,
      Y_THREAD,
      Z_THREAD,
      PERSISTENT_REDUCTION,
      T,
      Func>(
      out,
      inp_val,
      reduction_op,
      work_buf,
      sync_flags,
      shared_buf,
      read_pred,
      write_pred,
      init_val,
      entrance_ind,
      n_entrances);

  if (is_timing_thread) {
    cycles += readCycleCounter() - start_counter;
    ++count;
  }
}
#endif // PYTORCH_NVFUSER_PROFILE_KERNEL
// First stage of a grid reduction for a single value.
//
// It reduces inp_val across the thread dimensions whose X/Y/Z_THREAD is true
// (via blockReduce). The threads holding the block-reduced value then store
// it into work_buf. The row index is this block's offset along the reduced
// grid dimensions. The column index is the thread's offset along the
// non-reduced thread dimensions.
//
// work_buf must already point at this segment's region. The caller in this
// file (gridReduceGroup) advances it before calling.
//
// grid_reduction_segment_size and idx_in_grid_segment are accepted but not
// used by this body.
template <
    bool X_BLOCK,
    bool Y_BLOCK,
    bool Z_BLOCK,
    bool X_THREAD,
    bool Y_THREAD,
    bool Z_THREAD,
    typename T,
    typename Func>
__device__ void gridReduce2PartialReduction(
    const T& inp_val,
    T init_val,
    Func reduction_op,
    volatile T* work_buf,
    T* shared_buf,
    bool read_pred,
    nvfuser_index_t grid_reduction_segment_size,
    nvfuser_index_t idx_in_grid_segment,
    nvfuser_index_t block_reduction_segment_size) {
  T partial = init_val;
  if (X_THREAD || Y_THREAD || Z_THREAD) {
    // At least one thread dimension is reduced inside the block.
    blockReduce<X_THREAD, Y_THREAD, Z_THREAD>(
        partial,
        inp_val,
        reduction_op,
        threadIdx,
        blockDim,
        shared_buf,
        read_pred,
        true,
        init_val);
  } else if (read_pred) {
    // Nothing to reduce in the block: every thread forwards its own value.
    partial = inp_val;
  }

  // Only threads at offset 0 of every block-reduced dimension hold a result.
  const bool holds_block_result = (!X_THREAD || threadIdx.x == 0) &&
      (!Y_THREAD || threadIdx.y == 0) && (!Z_THREAD || threadIdx.z == 0);
  if (!holds_block_result) {
    return;
  }

  const auto row =
      index_utils::maskedOffset<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim);
  const auto column =
      index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>(
          threadIdx, blockDim);
  work_buf[row * block_reduction_segment_size + column] = partial;
}
// 2-way horizontally fused grid reduction.
//
// It performs two independent grid reductions:
// - value 1, with reduction_op1, init_val1, work_buf1 and out1;
// - value 2, with reduction_op2, init_val2, work_buf2 and out2.
//
// Both share one reduction-segment layout, one set of sync flags and one
// shared-memory scratch buffer. Both partial results are written before a
// single grid synchronization, instead of one synchronization per value.
// Template parameters and the remaining arguments have the same meaning as
// for gridReduce.
//
// shared_buf is reinterpreted as T1* and as T2* in turn.
// NOTE(review): it presumably must be sized for the larger of the two
// types; the caller is responsible for that.
template <
    bool X_BLOCK,
    bool Y_BLOCK,
    bool Z_BLOCK,
    bool X_THREAD,
    bool Y_THREAD,
    bool Z_THREAD,
    bool PERSISTENT_REDUCTION,
    typename T1,
    typename Func1,
    typename T2,
    typename Func2>
__device__ void gridReduceGroup(
    T1& out1,
    const T1& inp_val1,
    T1 init_val1,
    Func1 reduction_op1,
    volatile T1* work_buf1,
    T2& out2,
    const T2& inp_val2,
    T2 init_val2,
    Func2 reduction_op2,
    volatile T2* work_buf2,
    int64_t* sync_flags,
    void* shared_buf,
    bool read_pred,
    bool write_pred,
    const nvfuser_index_t entrance_ind,
    const nvfuser_index_t n_entrances) {
  // Number of thread blocks reduced together in one segment
  const auto grid_reduction_segment_size =
      index_utils::maskedSize<X_BLOCK, Y_BLOCK, Z_BLOCK>(gridDim);
  // Index of the segment this block belongs to
  const auto idx_in_grid_segment =
      index_utils::maskedOffset<!X_BLOCK, !Y_BLOCK, !Z_BLOCK>(
          blockIdx, gridDim);
  // Values each block contributes per reduced quantity (one per position
  // along the non-block-reduced thread dimensions)
  const auto block_reduction_segment_size =
      index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim);
  // Number of segments in the grid; collapsed to 1 for persistent reductions
  const nvfuser_index_t grid_segment_size = PERSISTENT_REDUCTION
      ? 1
      : index_utils::maskedSize<!X_BLOCK, !Y_BLOCK, !Z_BLOCK>(gridDim);

  // Both work buffers share the same layout, so each advances by the same
  // element count to reach this entrance's and this segment's region.
  const auto segment_offset =
      (entrance_ind * grid_segment_size + idx_in_grid_segment) *
      grid_reduction_segment_size * block_reduction_segment_size;
  work_buf1 += segment_offset;
  work_buf2 += segment_offset;

  // Stage 1: every block publishes its partial result for each value.
  gridReduce2PartialReduction<
      X_BLOCK,
      Y_BLOCK,
      Z_BLOCK,
      X_THREAD,
      Y_THREAD,
      Z_THREAD>(
      inp_val1,
      init_val1,
      reduction_op1,
      work_buf1,
      static_cast<T1*>(shared_buf),
      read_pred,
      grid_reduction_segment_size,
      idx_in_grid_segment,
      block_reduction_segment_size);
  gridReduce2PartialReduction<
      X_BLOCK,
      Y_BLOCK,
      Z_BLOCK,
      X_THREAD,
      Y_THREAD,
      Z_THREAD>(
      inp_val2,
      init_val2,
      reduction_op2,
      work_buf2,
      static_cast<T2*>(shared_buf),
      read_pred,
      grid_reduction_segment_size,
      idx_in_grid_segment,
      block_reduction_segment_size);

  // Persistent reductions reuse one flag per segment; otherwise every
  // entrance gets its own flag.
  const nvfuser_index_t sync_flag_idx = PERSISTENT_REDUCTION
      ? idx_in_grid_segment
      : entrance_ind * grid_segment_size + idx_in_grid_segment;
  grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION>(
      sync_flags[sync_flag_idx], grid_reduction_segment_size);

  // Stage 2: the last block of each segment folds all partial results.
  if (index_utils::maskedIsLast<X_BLOCK, Y_BLOCK, Z_BLOCK>(
          blockIdx, gridDim)) {
    gridReduceLastBlock<!X_THREAD, !Y_THREAD, !Z_THREAD>(
        out1,
        work_buf1,
        grid_reduction_segment_size,
        block_reduction_segment_size,
        reduction_op1,
        static_cast<T1*>(shared_buf),
        write_pred,
        init_val1);
    gridReduceLastBlock<!X_THREAD, !Y_THREAD, !Z_THREAD>(
        out2,
        work_buf2,
        grid_reduction_segment_size,
        block_reduction_segment_size,
        reduction_op2,
        static_cast<T2*>(shared_buf),
        write_pred,
        init_val2);
  }

  if (PERSISTENT_REDUCTION) {
    // Keep any block from moving on (and reusing the work buffers) before
    // the last block has finished reading them. In persistent mode
    // sync_flag_idx is idx_in_grid_segment, i.e. the same flag as above.
    grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION>(
        sync_flags[sync_flag_idx], grid_reduction_segment_size);
  }
}
#ifdef PYTORCH_NVFUSER_PROFILE_KERNEL
// Profiling overload of gridReduceGroup. It takes the same arguments plus two
// counters and forwards everything to the non-profiling gridReduceGroup. A
// single thread samples readCycleCounter() around that call: thread
// (0, 0, 0) of the last block of the whole grid. That thread adds the elapsed
// cycles to `cycles` and increments `count`. All other threads leave the
// counters untouched.
template <
    bool X_BLOCK,
    bool Y_BLOCK,
    bool Z_BLOCK,
    bool X_THREAD,
    bool Y_THREAD,
    bool Z_THREAD,
    bool PERSISTENT_REDUCTION,
    typename T1,
    typename Func1,
    typename T2,
    typename Func2>
__device__ void gridReduceGroup(
    T1& out1,
    const T1& inp_val1,
    T1 init_val1,
    Func1 reduction_op1,
    volatile T1* work_buf1,
    T2& out2,
    const T2& inp_val2,
    T2 init_val2,
    Func2 reduction_op2,
    volatile T2* work_buf2,
    int64_t* sync_flags,
    void* shared_buf,
    bool read_pred,
    bool write_pred,
    const nvfuser_index_t entrance_ind,
    const nvfuser_index_t n_entrances,
    int64_t& cycles,
    int64_t& count) {
  // The predicate depends only on blockIdx/gridDim/threadIdx, so evaluating
  // it once up front gives the same answer before and after the call.
  const bool is_timing_thread =
      index_utils::maskedIsLast<true, true, true>(blockIdx, gridDim) &&
      index_utils::maskedIsZero<true, true, true>(threadIdx);

  const int64_t start_counter = is_timing_thread ? readCycleCounter() : 0;

  gridReduceGroup<
      X_BLOCK,
      Y_BLOCK,
      Z_BLOCK,
      X_THREAD,
      Y_THREAD,
      Z_THREAD,
      PERSISTENT_REDUCTION,
      T1,
      Func1,
      T2,
      Func2>(
      out1,
      inp_val1,
      init_val1,
      reduction_op1,
      work_buf1,
      out2,
      inp_val2,
      init_val2,
      reduction_op2,
      work_buf2,
      sync_flags,
      shared_buf,
      read_pred,
      write_pred,
      entrance_ind,
      n_entrances);

  if (is_timing_thread) {
    cycles += readCycleCounter() - start_counter;
    ++count;
  }
}
#endif // PYTORCH_NVFUSER_PROFILE_KERNEL
} // namespace reduction
|