# mypy: allow-untyped-defs
import functools
from typing import Dict
import sympy
from torch._inductor import config
from torch._inductor.codegen.simd import IterationRangesRoot
from torch._inductor.codegen.triton import triton_compute_type, TritonKernel
from torch._inductor.runtime.triton_heuristics import split_scan_grid
from torch.utils._sympy.functions import CeilDiv
from ..utils import sympy_product
from .simd import prefix_is_reduction


class TritonSplitScanKernel(TritonKernel):
"""Generates a triton kernel that supports ops.scan calls while also splitting
the reduction dimension over multiple triton programs.
For this kernel, loop numels will always take the form ``(xdim, rdim)``
and the grid has the shape ``(CeilDiv(rdim, RBLOCK), xdim)``. Communication
between blocks occurs within a global memory workspace buffer, which
must be zero-filled before launching the kernel.
Note that generation for ``ops.reduction`` is not supported.
For details of the communication strategy, see
https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back
"""
def __init__(
self,
tiling: Dict[str, sympy.Expr],
pid_cache=None,
fixed_config=None,
**kwargs,
) -> None:
assert pid_cache is None, "not supported"
assert fixed_config is None, "not supported"
super().__init__(
tiling,
**kwargs,
)
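        # The pointwise (x) range is handled purely by the launch grid; it is
        # not materialized as a tensor dimension inside the kernel (note
        # tensor_dim=None for non-reduction prefixes in initialize_range_tree).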
self.no_x_dim = True
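
    # Split scan coordinates across programs through its own workspace buffer,
    # so the generic persistent/cooperative reduction strategies are not used.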
def should_use_persistent_reduction(self) -> bool:
return False
def should_use_cooperative_reduction(self) -> bool:
return False
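
    # Maps each active prefix to a launch-grid dimension. The reduction prefix
    # gets grid dim 0, matching the (CeilDiv(rnumel, RBLOCK), xnumel) grid
    # described in the class docstring; only the reduction axis is kept as a
    # tensor dimension inside the kernel.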
def initialize_range_tree(self, pid_cache):
prefixes = "yxr"
assert len(self.numels) <= len(
prefixes
), "z dimension not supported for split scan"
active_prefixes = prefixes[len(prefixes) - len(self.numels) :]
grid_dims = "rxy"
for prefix in active_prefixes:
numel = self.numels[prefix]
is_reduction = prefix == "r"
tensor_dim = 0 if is_reduction else None
grid_dim = grid_dims.find(prefix)
self.range_trees.append(
IterationRangesRoot(
f"{prefix}index",
numel,
prefix,
grid_dim,
self,
pid_cache=pid_cache,
is_loop=False,
tensor_dim=tensor_dim,
grid_dim=grid_dim,
has_zdim=False,
)
)
def reduction(self, dtype, src_dtype, reduction_type, value):
        raise NotImplementedError("NYI TritonSplitScanKernel reductions")
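
    # The scan is emitted in three phases: (1) reduce the values owned by this
    # program to a single block aggregate, (2) run a decoupled look-back over
    # the global scratch buffer to obtain the exclusive prefix of all earlier
    # blocks, and (3) combine that prefix with a block-local associative scan.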
def scan(self, dtypes, combine_fn, values):
import triton.language as tl
(dtype,) = dtypes
(value,) = values
compute_type = triton_compute_type(dtype)
compute_type_triton = getattr(tl, compute_type[3:])
element_nbits = compute_type_triton.primitive_bitwidth
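        # Choose a scratch word wide enough for the decoupled look-back helper:
        # elements of <=16 bits use a uint32 slot and wider elements a uint64
        # slot (presumably so the value can be packed alongside its status
        # flag), while 64-bit elements fall back to three separate uint64
        # slots per block, handled by the *_64 helper below.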
scratch_type = "tl.uint32" if element_nbits <= 16 else "tl.uint64"
scratch_type_triton = getattr(tl, scratch_type[3:])
scratch_elems_per_block = 3 if element_nbits == 64 else 1
scratch_nbytes_per_block = scratch_elems_per_block * (
scratch_type_triton.primitive_bitwidth // 8
)
cse_load = functools.partial(self.cse.generate, self.loads, dtype=dtype)
cse_compute = functools.partial(self.cse.generate, self.compute)
assert len(self.numels) == 2, "Unexpected tiling"
min_rblock = config.triton.min_split_scan_rblock
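        # Size the workspace for the worst case: every pointwise index may be
        # split into CeilDiv(reduction_numel, min_rblock) programs, and each
        # program needs scratch_nbytes_per_block bytes. The buffer is
        # zero-filled so that, presumably, unpublished blocks read as
        # "not ready" during the look-back.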
reduction_numel = sympy_product(
numel
for prefix, numel in self.numels.items()
if prefix_is_reduction(prefix)
)
pointwise_numel = sympy_product(
numel
for prefix, numel in self.numels.items()
if not prefix_is_reduction(prefix)
)
max_blocks = pointwise_numel * CeilDiv(reduction_numel, min_rblock)
nbytes = scratch_nbytes_per_block * max_blocks
scratch_base, offset = self.args.workspace(nbytes=nbytes, zero_fill=True)
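        # Rebase the scratch pointer: apply the byte offset of this kernel's
        # slice of the shared workspace, then advance by xoffset rows so each
        # pointwise index scans through its own region of scratch entries.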
if offset != 0:
scratch_base = cse_load(f"{scratch_base} + {self.index_to_str(offset)}")
runtime_rblocks = cse_load(f"tl.num_programs({self.range_trees[-1].index})")
scratch_base = cse_load(
f"{scratch_base}.to(tl.pointer_type({scratch_type})) + xoffset * "
f"{scratch_elems_per_block} * {runtime_rblocks}"
)
masks = {f"{tree.prefix}mask" for tree in self.range_trees}
self.filter_masks(masks)
assert not self._load_mask, "ops.scan not supported inside ops.masked"
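        # Promote the input to the Triton compute type and broadcast it to the
        # dense block shape expected by tl.reduce / tl.associative_scan.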
value = cse_compute(
f"{value}.to({compute_type})",
dtype=dtype,
)
value = cse_compute(
f"tl.broadcast_to({value}, {self.dense_size_str()})",
dtype=dtype,
)
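        # Lift the combine function into a standalone helper that tl.reduce,
        # tl.associative_scan, and the look-back helpers can call, then reduce
        # this program's values to a single block aggregate.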
combine_helper_fn = self._lift_helper(combine_fn, 1)
dim = self.triton_tensor_ndim() - 1
        assert dim == 0, "split scan expects a single (reduction) tensor dimension"
block_sum = cse_compute(
f"tl.reduce({value}, {dim}, {combine_helper_fn})",
dtype=dtype,
)
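        # Decoupled look-back (see the paper linked in the class docstring):
        # each program publishes its block aggregate to the scratch buffer and
        # inspects the entries of preceding blocks to build its exclusive
        # prefix. 64-bit elements go through a dedicated helper with the
        # three-word scratch layout sized above.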
exclusive_prefix = self.cse.newvar(
dtype=dtype,
)
if element_nbits == 64:
self.compute.splice(
f"""
{exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback_64(
{scratch_base},
{block_sum},
{self.iteration_ranges_get_pid(self.range_trees[-1])},
{combine_helper_fn},
)
""",
strip=True,
)
else:
assert element_nbits <= 32
value_as_uint_dtype = f"tl.uint{element_nbits}"
self.compute.splice(
f"""
{exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback(
{scratch_base},
{block_sum},
{self.iteration_ranges_get_pid(self.range_trees[-1])},
{combine_helper_fn},
DTYPE_VALUE_AS_UINT={value_as_uint_dtype},
DTYPE_PACK={scratch_type},
)
""",
strip=True,
)
        # Combine the block-local inclusive scan with the exclusive prefix of
        # all preceding blocks to produce the final scan value.
block_scan = cse_compute(
f"tl.associative_scan({value}, {dim}, {combine_helper_fn})",
dtype=dtype,
)
combined_result = cse_compute(
f"{combine_helper_fn}({exclusive_prefix}, {block_scan})",
dtype=dtype,
)
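        # The first block along the scanned dimension (roffset == 0) has no
        # preceding blocks, so its block-local scan is already the final
        # result; every other block folds in the exclusive prefix.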
return (
cse_compute(
f"tl.where(roffset == 0, {block_scan}, {combined_result})",
dtype=dtype,
),
)
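
    # Route this kernel through the split-scan autotuning heuristic and the
    # split_scan_grid launch-grid helper.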
def _get_heuristic(self):
return "split_scan"
def _get_grid_fn_str(self):
return "split_scan_grid"
def _get_grid_fn(self):
return split_scan_grid