File: triton_split_scan.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (206 lines) | stat: -rw-r--r-- 7,118 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# mypy: allow-untyped-defs
import functools
from typing import Dict

import sympy

from torch._inductor import config
from torch._inductor.codegen.simd import IterationRangesRoot
from torch._inductor.codegen.triton import triton_compute_type, TritonKernel
from torch._inductor.runtime.triton_heuristics import split_scan_grid
from torch.utils._sympy.functions import CeilDiv

from ..utils import sympy_product
from .simd import prefix_is_reduction


class TritonSplitScanKernel(TritonKernel):
    """Generates a triton kernel that supports ops.scan calls while also splitting
    the reduction dimension over multiple triton programs.

    For this kernel, loop numels will always take the form ``(xdim, rdim)``
    and the grid has the shape ``(CeilDiv(rdim, RBLOCK), xdim)``. Communication
    between blocks occurs within a global memory workspace buffer, which
    must be zero-filled before launching the kernel.

    Note that generation for ``ops.reduction`` is not supported.

    For details of the communication strategy, see
    https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back

    """

    def __init__(
        self,
        tiling: Dict[str, sympy.Expr],
        pid_cache=None,
        fixed_config=None,
        **kwargs,
    ) -> None:
        """Create the kernel for the given tiling.

        Args:
            tiling: mapping from dimension prefix to its numel; forwarded
                unchanged to the base-class constructor.
            pid_cache: must be ``None``; accepted only for signature
                compatibility and not forwarded to the base class.
            fixed_config: must be ``None``; accepted only for signature
                compatibility and not forwarded to the base class.
            **kwargs: forwarded unchanged to the base-class constructor.

        Raises:
            AssertionError: if ``pid_cache`` or ``fixed_config`` is given.
        """
        assert pid_cache is None, "not supported"
        assert fixed_config is None, "not supported"
        super().__init__(
            tiling,
            **kwargs,
        )
        # Set after the base constructor has run so that this value is the
        # one left on the instance.
        self.no_x_dim = True

    def should_use_persistent_reduction(self) -> bool:
        """Always return ``False``: this kernel never opts into persistent reductions."""
        return False

    def should_use_cooperative_reduction(self) -> bool:
        """Always return ``False``: this kernel never opts into cooperative reductions."""
        return False

    def initialize_range_tree(self, pid_cache):
        """Append one ``IterationRangesRoot`` per active dimension to ``self.range_trees``.

        The active prefixes are the trailing ``len(self.numels)`` characters of
        ``"yxr"`` (e.g. ``"xr"`` for two dimensions). Each prefix is assigned
        the grid dimension equal to its position in ``"rxy"``, i.e. ``r -> 0``,
        ``x -> 1``, ``y -> 2``. Only the ``r`` range is given a tensor
        dimension (``0``); all other ranges get ``tensor_dim=None``.

        Args:
            pid_cache: passed through to every ``IterationRangesRoot``.

        Raises:
            AssertionError: if there are more than three dimensions.
        """
        prefixes = "yxr"
        assert len(self.numels) <= len(
            prefixes
        ), "z dimension not supported for split scan"
        active_prefixes = prefixes[len(prefixes) - len(self.numels) :]

        # Position within this string is the grid dimension of each prefix.
        grid_dims = "rxy"
        for prefix in active_prefixes:
            numel = self.numels[prefix]
            is_reduction = prefix == "r"
            tensor_dim = 0 if is_reduction else None
            grid_dim = grid_dims.find(prefix)
            self.range_trees.append(
                IterationRangesRoot(
                    f"{prefix}index",
                    numel,
                    prefix,
                    # The fourth positional argument receives the same value
                    # as the ``grid_dim`` keyword below.
                    grid_dim,
                    self,
                    pid_cache=pid_cache,
                    is_loop=False,
                    tensor_dim=tensor_dim,
                    grid_dim=grid_dim,
                    has_zdim=False,
                )
            )

    def reduction(self, dtype, src_dtype, reduction_type, value):
        """Reject ``ops.reduction``; this kernel only generates scans.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("NYI TritonSplitScanKernel reductions")

    def scan(self, dtypes, combine_fn, values):
        """Emit Triton code for a single-value scan split across programs.

        The generated code (in emission order):

        1. requests a zero-filled workspace and derives this row's scratch
           pointer from it,
        2. casts ``value`` to the compute type and broadcasts it to the
           dense block shape,
        3. reduces the block with ``tl.reduce`` to obtain the block total,
        4. calls ``triton_helpers.exclusive_scan_decoupled_lookback`` (or the
           ``_64`` variant for 64-bit elements) to obtain the exclusive
           prefix contributed by the preceding blocks,
        5. runs ``tl.associative_scan`` inside the block and combines it with
           that prefix, except where ``roffset == 0``.

        Args:
            dtypes: 1-tuple holding the dtype of the scanned value.
            combine_fn: the associative combine function; it is lifted to a
                helper via ``self._lift_helper``.
            values: 1-tuple holding the value to scan.

        Returns:
            A 1-tuple containing the CSE variable of the scan result.

        Raises:
            ValueError: if ``dtypes`` or ``values`` does not contain exactly
                one element (from tuple unpacking).
            AssertionError: if the tiling does not have exactly two numels,
                if called inside ``ops.masked``, if the kernel's tensor is not
                one-dimensional, or if the element width is neither 64 bits
                nor at most 32 bits.
        """
        import triton.language as tl

        (dtype,) = dtypes
        (value,) = values

        compute_type = triton_compute_type(dtype)
        # ``compute_type`` is a string of the form "tl.<name>"; drop the
        # three-character "tl." prefix to look the type object up on ``tl``.
        compute_type_triton = getattr(tl, compute_type[3:])

        element_nbits = compute_type_triton.primitive_bitwidth

        # Elements of at most 16 bits use 32-bit scratch words, everything
        # else uses 64-bit words. 64-bit elements need three words per block.
        # NOTE(review): the per-block layout is defined by the triton_helpers
        # look-back routines, which are not visible here -- confirm there.
        scratch_type = "tl.uint32" if element_nbits <= 16 else "tl.uint64"
        scratch_type_triton = getattr(tl, scratch_type[3:])
        scratch_elems_per_block = 3 if element_nbits == 64 else 1
        scratch_nbytes_per_block = scratch_elems_per_block * (
            scratch_type_triton.primitive_bitwidth // 8
        )

        cse_load = functools.partial(self.cse.generate, self.loads, dtype=dtype)
        cse_compute = functools.partial(self.cse.generate, self.compute)

        assert len(self.numels) == 2, "Unexpected tiling"
        min_rblock = config.triton.min_split_scan_rblock
        reduction_numel = sympy_product(
            numel
            for prefix, numel in self.numels.items()
            if prefix_is_reduction(prefix)
        )
        pointwise_numel = sympy_product(
            numel
            for prefix, numel in self.numels.items()
            if not prefix_is_reduction(prefix)
        )
        # Size the workspace for the largest possible number of blocks: the
        # smallest permitted RBLOCK yields the most blocks along the
        # reduction dimension.
        max_blocks = pointwise_numel * CeilDiv(reduction_numel, min_rblock)
        nbytes = scratch_nbytes_per_block * max_blocks
        scratch_base, offset = self.args.workspace(nbytes=nbytes, zero_fill=True)
        if offset != 0:
            scratch_base = cse_load(f"{scratch_base} + {self.index_to_str(offset)}")
        # Number of programs actually launched along the grid axis of the
        # last range tree.
        runtime_rblocks = cse_load(f"tl.num_programs({self.range_trees[-1].index})")
        # Reinterpret the workspace as scratch words and advance to the
        # region belonging to the current ``xoffset``.
        scratch_base = cse_load(
            f"{scratch_base}.to(tl.pointer_type({scratch_type})) + xoffset * "
            f"{scratch_elems_per_block} * {runtime_rblocks}"
        )

        # NOTE(review): ``masks`` is not read again after filtering; the call
        # is kept in case ``filter_masks`` has effects not visible here.
        masks = {f"{tree.prefix}mask" for tree in self.range_trees}
        self.filter_masks(masks)
        assert not self._load_mask, "ops.scan not supported inside ops.masked"

        value = cse_compute(
            f"{value}.to({compute_type})",
            dtype=dtype,
        )
        value = cse_compute(
            f"tl.broadcast_to({value}, {self.dense_size_str()})",
            dtype=dtype,
        )

        combine_helper_fn = self._lift_helper(combine_fn, 1)
        dim = self.triton_tensor_ndim() - 1
        assert dim == 0, "split scan expects a one-dimensional (reduction-only) tensor"

        # Total of this block, published to / combined with other blocks by
        # the look-back helper below.
        block_sum = cse_compute(
            f"tl.reduce({value}, {dim}, {combine_helper_fn})",
            dtype=dtype,
        )
        exclusive_prefix = self.cse.newvar(
            dtype=dtype,
        )
        if element_nbits == 64:
            self.compute.splice(
                f"""
                {exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback_64(
                    {scratch_base},
                    {block_sum},
                    {self.iteration_ranges_get_pid(self.range_trees[-1])},
                    {combine_helper_fn},
                )
                """,
                strip=True,
            )

        else:
            assert element_nbits <= 32
            value_as_uint_dtype = f"tl.uint{element_nbits}"

            self.compute.splice(
                f"""
                {exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback(
                    {scratch_base},
                    {block_sum},
                    {self.iteration_ranges_get_pid(self.range_trees[-1])},
                    {combine_helper_fn},
                    DTYPE_VALUE_AS_UINT={value_as_uint_dtype},
                    DTYPE_PACK={scratch_type},
                )
                """,
                strip=True,
            )
        # Compute final cumsum
        block_scan = cse_compute(
            f"tl.associative_scan({value}, {dim}, {combine_helper_fn})",
            dtype=dtype,
        )
        combined_result = cse_compute(
            f"{combine_helper_fn}({exclusive_prefix}, {block_scan})",
            dtype=dtype,
        )
        # Where ``roffset == 0`` the in-block scan is used as-is and the
        # exclusive prefix is ignored (presumably because the first block has
        # no predecessors -- confirm against the look-back helper).
        return (
            cse_compute(
                f"tl.where(roffset == 0, {block_scan}, {combined_result})",
                dtype=dtype,
            ),
        )

    def _get_heuristic(self):
        """Return the name of the heuristic used for this kernel."""
        return "split_scan"

    def _get_grid_fn_str(self):
        """Return the name of the grid function as a string."""
        return "split_scan_grid"

    def _get_grid_fn(self):
        """Return the grid function object matching ``_get_grid_fn_str``."""
        return split_scan_grid