File: triton_combo_kernel.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (1121 lines) | stat: -rw-r--r-- 46,988 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
import itertools
import logging
import textwrap
from collections import defaultdict
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
    Union,
)

import sympy
from sympy import Integer, Symbol

from .. import config, metrics
from ..runtime.hints import DeviceProperties
from ..runtime.runtime_utils import next_power_of_2
from ..runtime.triton_heuristics import grid_combo_kernels
from ..scheduler import BaseSchedulerNode
from ..utils import Placeholder
from ..virtualized import V
from .common import (
    DeferredLine,
    IndentedBuffer,
    Kernel,
    PythonPrinter,
    SizeArg,
    WorkspaceArg,
)
from .simd import prefix_is_reduction, SIMDScheduling
from .simd_kernel_features import SIMDKernelFeatures
from .triton import gen_common_triton_imports, TritonKernel
from .triton_utils import config_of, signature_to_meta


log = logging.getLogger(__name__)
pexpr = PythonPrinter().doprint
LARGE_NUMELS = 512e5
BLOCK_UTILIZATION = 0.8


def _default_custom_combo_kernel_horizontal_partition(
    nodes: List[BaseSchedulerNode],
    triton_scheduling: SIMDScheduling,
    kernel_map: Dict[BaseSchedulerNode, TritonKernel],
    node_info_map: Dict[BaseSchedulerNode, Tuple[Any, Any, Any, Any]],
) -> List[List[BaseSchedulerNode]]:
    """Horizontally partition the given list of nodes into a list of list of nodes where each sublist
    represents a partition. Nodes in different partitions are implemented in different combo kernels.
    Nodes in the same partition are likely to be implemented in the same combo kernel, but subject to
    subsequent restrictions like CUDA limits for number of args.

    Input arguments:
        nodes: a list of fused scheduler nodes to partition (must be non-empty).
        triton_scheduling: TritonScheduling instance (unused by this default algorithm).
        kernel_map: a map from node to its kernel.
        node_info_map: a map from node to (node_schedule, tiled_groups, numel, rnumel).
    Output:
        a list of list of nodes with each sublist representing a partition.

    The default algorithm partitions nodes based on the following rules:
        1) nodes with the same number of tiling dimensions are grouped together
           (nodes whose tiling has fewer than 2 entries are not placed anywhere,
           which trips the final consistency assert).
        2) large pointwise nodes (x numel greater than LARGE_NUMELS) are each put
           in a partition of their own.
        3) reduction nodes with rnumel > 2048 ("long" reductions) are separated
           from shorter reductions and from non-reduction nodes.
    """

    assert len(nodes) >= 1

    # first partition nodes based on number of block dimensions
    tilings = [node_info_map[n][1] for n in nodes]

    max_dims = max(len(t) for t in tilings)
    nodes_per_ndim: List[List[BaseSchedulerNode]] = []
    for i in range(2, max_dims + 1):
        group_per_dim = [n for n, t in zip(nodes, tilings) if len(t) == i]
        # Reduction kernels, except persistent reductions without an x dim,
        # are handled as "reduction" nodes; the rest are treated as pointwise.
        reduction = [
            n
            for n in group_per_dim
            if kernel_map[n].inside_reduction
            and not (kernel_map[n].persistent_reduction and kernel_map[n].no_x_dim)
        ]
        # Sets give O(1) membership tests; lists keep the original node order.
        reduction_set = set(reduction)
        not_reduction = [n for n in group_per_dim if n not in reduction_set]
        # rnumel > 2048 usually has long execution time
        # BaseSchedulerNode.group[-1][-1] is rnumel for reduction nodes
        long_reduction = [
            n for n in reduction if V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048  # type: ignore[arg-type]
        ]
        long_reduction_set = set(long_reduction)
        short_reduction = [n for n in reduction if n not in long_reduction_set]
        if long_reduction:
            log.warning(
                "ComboKernels: %d long reduction nodes are separated",
                len(long_reduction),
            )
        large_pointwise = [
            n
            for n in not_reduction
            if not kernel_map[n].inside_reduction
            and len(kernel_map[n].numels) == 2
            and V.graph.sizevars.size_hint(kernel_map[n].numels["x"]) > LARGE_NUMELS
        ]
        if large_pointwise:
            # TODO benchmark the performance when large pointwise nodes combining with others
            log.warning(
                "ComboKernels: %d large pointwise nodes are separated",
                len(large_pointwise),
            )
            large_pointwise_set = set(large_pointwise)
            not_reduction = [n for n in not_reduction if n not in large_pointwise_set]
            # each large pointwise node becomes a singleton partition
            nodes_per_ndim.extend([node] for node in large_pointwise)

        nodes_per_ndim.extend(
            g for g in (not_reduction, short_reduction, long_reduction) if g
        )

    assert sum(len(p) for p in nodes_per_ndim) == len(nodes)
    return nodes_per_ndim


# Partitioning hook called by ComboKernel.horizontal_partition() when it is
# invoked with custom_algorithm=True. It defaults to
# _default_custom_combo_kernel_horizontal_partition and is replaced through
# set_custom_combo_kernel_horizontal_partition().
_custom_combo_kernel_horizontal_partition_algorithm: Callable[
    [
        List[BaseSchedulerNode],
        SIMDScheduling,
        Dict[BaseSchedulerNode, TritonKernel],
        Dict[BaseSchedulerNode, Tuple[Any, Any, Any, Any]],
    ],
    List[List[BaseSchedulerNode]],
] = _default_custom_combo_kernel_horizontal_partition


def set_custom_combo_kernel_horizontal_partition(
    algorithm: Callable[
        [
            List[BaseSchedulerNode],
            SIMDScheduling,
            Dict[BaseSchedulerNode, TritonKernel],
            Dict[BaseSchedulerNode, Tuple[Any, Any, Any, Any]],
        ],
        List[List[BaseSchedulerNode]],
    ]
) -> None:
    """Install ``algorithm`` as the module-wide horizontal partitioning hook.

    The hook is consulted by ``ComboKernel.horizontal_partition`` when it is called
    with ``custom_algorithm=True``. It receives
    ``(nodes, triton_scheduling, kernel_map, node_info_map)`` and must return a list
    of node lists. Nodes placed in different lists are implemented in different
    combo kernels; nodes in the same list are likely to share one, but may still be
    split by later restrictions such as the CUDA limit on the number of args.

    The default hook (``_default_custom_combo_kernel_horizontal_partition``) groups
    nodes by their number of block dimensions and also separates large pointwise
    and long reduction nodes.
    """
    # Rebinding a module global requires the ``global`` declaration.
    global _custom_combo_kernel_horizontal_partition_algorithm
    _custom_combo_kernel_horizontal_partition_algorithm = algorithm


@dataclass
class PartitionState:
    """Accumulator used while splitting nodes into combo-kernel partitions.

    ``partitions`` holds the partitions that are already closed, ``cur_partition``
    is the one still being filled, and ``cur_count`` is its running read/write count.
    """

    # partitions that have already been closed
    partitions: List[List[BaseSchedulerNode]]
    # partition currently being filled
    cur_partition: List[BaseSchedulerNode]
    # summed read/write count of the nodes in cur_partition
    cur_count: int

    def finalize(self) -> None:
        """Close the in-progress partition, but only if it holds any nodes."""
        if not self.cur_partition:
            return
        self.partitions.append(self.cur_partition)


class ComboKernel(Kernel):
    # Cap on the summed read+write count of the nodes grouped into a single
    # combo-kernel partition; enforced in _update_partition().
    MAX_NUM_ARGS = 250  # number where I would no longer get triton errors

    @staticmethod
    def _update_partition(
        partition_state: PartitionState,
        node_rw_count: int,
        node_info: BaseSchedulerNode,
    ) -> None:
        """Add ``node_info`` to the partition currently being built in ``partition_state``.

        Args:
            partition_state: accumulator for one group of compatible nodes.
            node_rw_count: number of reads + writes of ``node_info``.
            node_info: the scheduler node to place.

        If adding ``node_rw_count`` would push the running count over
        ``ComboKernel.MAX_NUM_ARGS``, the current partition is closed and the node
        starts a new one. An empty current partition is never closed, so a single
        node whose own count already exceeds the limit gets a partition of its own
        instead of producing an empty partition.
        """
        over_limit = (
            partition_state.cur_count + node_rw_count > ComboKernel.MAX_NUM_ARGS
        )
        if partition_state.cur_partition and over_limit:
            partition_state.partitions.append(partition_state.cur_partition)
            partition_state.cur_partition = [node_info]
            partition_state.cur_count = node_rw_count
        else:
            partition_state.cur_count += node_rw_count
            partition_state.cur_partition.append(node_info)

    @staticmethod
    def _base_horizontal_partition(
        subkernel_nodes: List[BaseSchedulerNode],
        triton_scheduling: SIMDScheduling,
        node_info_map: Dict[BaseSchedulerNode, Tuple[Any, Any, Any, Any]],
        custom_algorithm: bool,
    ) -> List[List[BaseSchedulerNode]]:
        """Split ``subkernel_nodes`` into lists of nodes, one list per combo kernel.

        Each returned partition keeps the summed read/write count of its nodes within
        ``ComboKernel.MAX_NUM_ARGS`` (see ``_update_partition``) and only contains nodes
        with a compatible tiling:

        * without mixed sizes, nodes whose tiling has 3 entries are grouped by their
          ``"y"`` tile size, and all other nodes by number of tiling entries;
        * with mixed sizes (``config.combo_kernel_allow_mixed_sizes > 1``, or ``== 1``
          together with ``custom_algorithm``), grouping is by number of tiling entries only.

        Partitions grouped by number of tiling entries are returned before those
        grouped by ``"y"`` size. ``triton_scheduling`` is accepted for interface
        symmetry and is not used here.

        Raises AssertionError if ``subkernel_nodes`` is empty, a node has fewer than 2
        tiling entries, or (without mixed sizes) a node has more than 3.
        """
        # TODO support combination of kernels with different block dimensions
        assert len(subkernel_nodes) >= 1
        mixed_sizes = config.combo_kernel_allow_mixed_sizes > 1 or (
            config.combo_kernel_allow_mixed_sizes == 1 and custom_algorithm
        )

        ndim_to_partition_state: Dict[int, PartitionState] = defaultdict(
            lambda: PartitionState([], [], 0)
        )
        yelem_to_partition_state: Dict[int, PartitionState] = defaultdict(
            lambda: PartitionState([], [], 0)
        )

        for node in subkernel_nodes:
            # node_info_map values are (node_schedule, tiled_groups, numel, rnumel)
            tiled_groups = node_info_map[node][1]

            read_writes = node.read_writes
            read_write_count = len(read_writes.reads) + len(read_writes.writes)

            ndim = len(tiled_groups)
            assert ndim >= 2, f"Combokernel not support tile {tiled_groups}"
            if not mixed_sizes and ndim == 3:
                partition_state = yelem_to_partition_state[tiled_groups["y"]]
            else:
                assert mixed_sizes or ndim <= 3, f"No mixed sizes: tile {tiled_groups}"
                partition_state = ndim_to_partition_state[ndim]
            ComboKernel._update_partition(partition_state, read_write_count, node)

        all_partitions: List[List[BaseSchedulerNode]] = []
        for partition_state in ndim_to_partition_state.values():
            partition_state.finalize()
            all_partitions.extend(partition_state.partitions)
        for partition_state in yelem_to_partition_state.values():
            partition_state.finalize()
            all_partitions.extend(partition_state.partitions)

        return all_partitions

    @staticmethod
    def horizontal_partition(
        nodes: List[BaseSchedulerNode],
        triton_scheduling: SIMDScheduling,
        kernel_map: Dict[BaseSchedulerNode, TritonKernel],
        node_info_map: Dict[BaseSchedulerNode, Tuple[Any, Any, Any, Any]],
        custom_algorithm: bool = False,
    ) -> List[List[BaseSchedulerNode]]:
        """Partition ``nodes`` into lists of scheduler nodes, each list forming one ComboKernel.

        ``node_info_map`` maps each node to ``(node_schedule, tiled_groups, numel, rnumel)``.
        Partitioning happens in two stages:
            1) if ``custom_algorithm`` is True, the installed hook
               ``_custom_combo_kernel_horizontal_partition_algorithm`` produces raw
               partitions; otherwise all nodes form a single raw partition.
            2) every raw partition is split by ``_base_horizontal_partition()`` so each
               final partition stays within the arg-count limit (read/writes) and only
               groups nodes with a compatible tiling.
        """
        if custom_algorithm:
            raw_partitions = _custom_combo_kernel_horizontal_partition_algorithm(
                nodes, triton_scheduling, kernel_map, node_info_map
            )
        else:
            raw_partitions = [nodes]

        all_partitions: List[List[BaseSchedulerNode]] = []
        for raw_partition in raw_partitions:
            all_partitions.extend(
                ComboKernel._base_horizontal_partition(
                    raw_partition, triton_scheduling, node_info_map, custom_algorithm
                )
            )
        return all_partitions

    class SequentialDispatch:
        """
        Dispatch strategy that lays the subkernels out one after another in pid space:
        program ids below ``num_xblocks_0`` run subkernel 0, the next range runs
        subkernel 1, and so on, so each subkernel's blocks are contiguous.
        Methods:
            codegen_pid_range(...): emit the pid-range branch selecting subkernel ``num``.
            grid(...): compute the grid numels used to launch the combo kernel.
        """

        @classmethod
        def codegen_pid_range(
            cls, kernel: "ComboKernel", num: int, code: IndentedBuffer
        ) -> None:
            # Every branch after the first is an ``elif`` that rebases pid onto the
            # end of the previous subkernel's block range.
            if num != 0:
                code.splice(f"elif pid < num_xblocks_{num}:")
                with code.indent():
                    code.splice(f"pid_offset = pid - num_xblocks_{num - 1}")
                return
            # The first branch also emits the cumulative num_xblocks_<i> values
            # that every later branch compares against.
            cls._calculate_xblocks(kernel, code)
            code.splice(f"if pid < num_xblocks_{num}:")
            with code.indent():
                code.splice("pid_offset = pid")

        @classmethod
        def _calculate_xblocks(
            cls, kernel: "ComboKernel", code: IndentedBuffer
        ) -> None:
            # Emit ``num_xblocks_<i>`` as the running total of x blocks used by
            # subkernels 0..i.
            for idx, x_numel in enumerate(kernel.x_numels_list):
                # A positive int or a string not starting with "-" is a tunable x
                # numel. Anything else (negated / "-"-prefixed entries, which
                # min_x_blocks_sub_kernel stores for no_x_dim kernels, or 0) falls
                # back to the fixed block count in min_x_blocks_list.
                tunable = (isinstance(x_numel, str) and x_numel[0] != "-") or (
                    isinstance(x_numel, int) and x_numel > 0
                )
                if tunable:
                    block_count = f"tl.cdiv({x_numel}, XBLOCK)"
                else:
                    block_count = f"{kernel.min_x_blocks_list[idx]}"
                running = "" if idx == 0 else f"num_xblocks_{idx - 1} + "
                code.splice(f"num_xblocks_{idx} = {running}{block_count}")

        @classmethod
        def grid(
            cls,
            sub_kernel_numels: List[List[int]],
            x_blocks_list: List[Union[str, int]],
            dynamic_shape: bool,
        ) -> Tuple[Any, ...]:
            def collapse(values: List[Any]) -> Any:
                # Dynamic shapes keep the per-subkernel list; static shapes use the max.
                if dynamic_shape:
                    return None if None in values else values
                # TODO: improve 1d/2d mixed cases
                if any(v is None for v in values):
                    return None
                return max(values)

            xnumel = list(x_blocks_list)
            y_values = [dims[-2] if len(dims) > 1 else None for dims in sub_kernel_numels]
            z_values = [dims[-3] if len(dims) > 2 else None for dims in sub_kernel_numels]
            ynumel = collapse(y_values)
            znumel = collapse(z_values)

            if not ynumel:
                return (xnumel,)
            if not znumel:
                return (ynumel, xnumel)
            return (znumel, ynumel, xnumel)

    class RoundRobinDispatch:
        """
        The dispatcher which dispatches the subkernels in a round robin manner:
        the blocks are interleavedly dispatched to each subkernel to execute them
        in parallel.
        The class defines the methods specific to the dispatch algorithm.
        Methods:
            codegen_pid_range(...): codegen the pid range for each subkernel.
            grid(...): codegen the grid size for launching the combo kernel.
        """

        @classmethod
        def codegen_pid_range(
            cls, kernel: "ComboKernel", num: int, code: IndentedBuffer
        ) -> None:
            # pid % num_kernels picks the subkernel; pid // num_kernels is the
            # block index within that subkernel.
            num_kernels = len(kernel.sub_kernels)
            if num == 0:
                cond = "if"
            else:
                cond = "elif"
            code.splice(f"{cond} pid % {num_kernels} == {num}:")
            with code.indent():
                code.splice(f"pid_offset = pid // {num_kernels}")

        @classmethod
        def grid(
            cls,
            sub_kernel_numels: List[List[int]],
            x_blocks_list: List[Union[str, int]],
            dynamic_shape: bool,
        ) -> Tuple[Any, ...]:
            """Return the launch numels as ``(x,)``, ``(y, x)`` or ``(z, y, x)``.

            Each dimension is None if any subkernel lacks it. Otherwise it is the
            per-subkernel list (``dynamic_shape``) or the maximum over subkernels
            (static). A missing (falsy) y or z dimension is dropped from the result.
            """
            xnumel: Any = x_blocks_list
            ynumel: Any = [e[-2] if len(e) > 1 else None for e in sub_kernel_numels]
            znumel: Any = [e[-3] if len(e) > 2 else None for e in sub_kernel_numels]

            # TODO: support 1d/2d mixed cases
            if any(e is None for e in xnumel):
                xnumel = None
            elif not dynamic_shape:
                # set no_x_dim xnumels (stored negated) to 0. Clamp only on the
                # static path: max(e, 0) raises TypeError on None or symbolic
                # (str) entries, which the other branches never need clamped.
                xnumel = max(max(e, 0) for e in xnumel)
            ynumel = (
                None
                if any(e is None for e in ynumel)
                else ynumel
                if dynamic_shape
                else max(ynumel)
            )
            znumel = (
                None
                if any(e is None for e in znumel)
                else znumel
                if dynamic_shape
                else max(znumel)
            )

            numels = (
                (xnumel,)
                if not ynumel
                else (ynumel, xnumel)
                if not znumel
                else (znumel, ynumel, xnumel)
            )
            return numels

    def __init__(
        self, enable_autotune: bool = False, mixed_sizes: bool = False
    ) -> None:
        super().__init__()
        # Caller-supplied configuration.
        self.enable_autotune = enable_autotune
        self.mixed_sizes = mixed_sizes

        # Per-subkernel bookkeeping, filled as subkernels are added and generated.
        self.sub_kernels: List[TritonKernel] = []
        self.iter_vars_count = itertools.count()
        self.grids: List[List[int]] = []
        self.min_x_blocks_list: List[Union[int, str]] = []
        self.x_numels_list: List[Union[int, str]] = []
        self.block_args: List[str] = []
        self.dynamic_shape_args: List[str] = []

        # Left as None until select_dispatch_strategy() picks one.
        self.dispatch_class: Optional[
            Union[
                Type[ComboKernel.SequentialDispatch],
                Type[ComboKernel.RoundRobinDispatch],
            ]
        ] = None

        # The following are used when autotuning is disabled.
        self.block_size_1d = 1024  # Try tuning this value
        self.block_size_2d = 32
        self.block_size_reduce = 256
        self.num_warps = 8

    def create_sub_kernel(self, triton_kernel: TritonKernel) -> TritonKernel:
        """Adopt ``triton_kernel`` (mutated in place) as a subkernel and return it.

        The subkernel is rewired to share this combo kernel's ``args``,
        ``iter_vars_count`` and CSE ``iter_buffer_ids``, then appended to
        ``self.sub_kernels``.
        """
        # NOTE(review): presumably offsets the count bumped when the TritonKernel
        # was created, since it is emitted as part of this combo kernel — confirm.
        metrics.generated_kernel_count -= 1
        # Presumably shared so generated names do not collide across subkernels.
        triton_kernel.args = self.args
        triton_kernel.iter_vars_count = self.iter_vars_count
        triton_kernel.cse.iter_buffer_ids = self.cse.iter_buffer_ids
        self.sub_kernels.append(triton_kernel)
        return triton_kernel

    @staticmethod
    def create_triton_kernel(
        tiling: Dict[str, sympy.Expr],
        features: SIMDKernelFeatures,
        optimize_mask: bool,
    ) -> TritonKernel:
        """
        Build a TritonKernel intended for use as a combo-kernel subkernel.

        ``tl.program_id(0)`` is mapped to ``pid_offset``, the per-subkernel block
        index emitted by the dispatcher's ``codegen_pid_range``. Cooperative
        reductions are disabled.

        Only allow optimize_mask=True when 1) sequential dispatch is used,
        2) numels except x dimension are the same for each sub kernel.
        """
        pid_cache = {"tl.program_id(0)": "pid_offset"}
        return TritonKernel(
            tiling,
            features=features,
            optimize_mask=optimize_mask,
            pid_cache=pid_cache,
            # foreach kernels don't work with cooperative reductions
            override_cooperative_reduction=False,
        )

    def codegen_static_numels_sub_kernel(
        self, code: IndentedBuffer, sub_kernel: TritonKernel, num: int
    ) -> List[str]:
        """
        Emit constant numel / block-size lines for subkernel ``num`` and record its grid.

        Hard coding static numels gives a small speedup. For a kernel like
        def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
        lines such as
            xnumel = 4096
            rnumel = 768
        are written after the signature, shadowing the passed-in values with
        constants. That gives triton a better signal for unrolling and static
        indexing.

        For every range tree of ``sub_kernel``:
          * a static numel is written as ``<prefix>numel = <value>``. A dynamic one
            must already be listed in ``self.dynamic_shape_args`` as
            ``<prefix>numel_<num>``, and ``<prefix>numel`` is collected for
            uniquification;
          * non-reduction trees add their numel (an int, or the ``<prefix>numel_<num>``
            name) to the grid that is appended to ``self.grids``;
          * for a persistent reduction, ``RBLOCK_<num>`` is fixed to the next power
            of two of the reduction numel;
          * for a no_x_dim kernel, ``XBLOCK_<num>`` is fixed to 1.

        Returns the collected base names (e.g. "xnumel", "RBLOCK", "XBLOCK").
        Raises RuntimeError if a persistent reduction has a dynamic reduction numel.
        """
        grid = []
        uniquify_block_sizes: List[str] = []
        for tree in sub_kernel.range_trees:
            numel = V.graph.sizevars.simplify(tree.numel)
            static_numel: Optional[int] = (
                int(numel) if isinstance(numel, (Integer, int)) else None
            )
            dynamic_name = f"{tree.prefix}numel_{num}"

            if static_numel is not None:
                code.writeline(f"{tree.prefix}numel = {static_numel}")
            else:
                assert dynamic_name in self.dynamic_shape_args
                uniquify_block_sizes.append(f"{tree.prefix}numel")

            if not tree.is_reduction:
                grid.append(static_numel if static_numel is not None else dynamic_name)
            elif sub_kernel.persistent_reduction:
                if static_numel is None:
                    raise RuntimeError(
                        "Dynamic shape on reduction dimension is not supported"
                    )
                rblock = next_power_of_2(static_numel)
                code.writeline(f"RBLOCK_{num}: tl.constexpr = {rblock}")
                uniquify_block_sizes.append("RBLOCK")

            if tree.prefix == "x" and sub_kernel.no_x_dim:
                code.writeline(f"XBLOCK_{num}: tl.constexpr = 1")
                uniquify_block_sizes.append("XBLOCK")
        self.grids.append(grid)
        return uniquify_block_sizes

    def min_x_blocks_sub_kernel(self, sub_kernel: TritonKernel, num: int) -> None:
        """
        Record the x numel and the fixed x block count of subkernel ``num``.

        Kernels with no_x_dim being true have no tunable XBLOCK and a fixed number
        of X blocks, so the grid calculation must reserve at least that many. For
        such kernels the x numel is stored negated (or "-"-prefixed when symbolic)
        in ``x_numels_list``, and its positive value goes to ``min_x_blocks_list``.
        Otherwise ``min_x_blocks_list`` receives 0. A symbolic numel is referred to
        by its argument name ``xnumel_<num>``.
        """
        min_x_blocks: Union[int, str] = 0
        x_numels: Union[int, str] = 0
        for tree in sub_kernel.range_trees:
            if tree.prefix != "x":
                continue
            numel = V.graph.sizevars.simplify(tree.numel)
            if isinstance(numel, (Integer, int)):
                x_numels = int(numel)
            else:
                x_numels = f"{tree.prefix}numel_{num}"
            if sub_kernel.no_x_dim:
                min_x_blocks = x_numels
                x_numels = (
                    -x_numels if isinstance(x_numels, int) else "-" + x_numels
                )
        self.min_x_blocks_list.append(min_x_blocks)
        self.x_numels_list.append(x_numels)

    def select_heuristics(self, sub_kernel: TritonKernel) -> Tuple[str, Dict[str, int]]:
        """Return ``(heuristics_name, size_hints)`` for one subkernel.

        ``size_hints`` maps each numel prefix to the next power of two of its size
        hint. Reduction prefixes are only included when the kernel is inside a
        reduction. The name is "persistent_reduction", "reduction" or "pointwise".
        """
        size_hints: Dict[str, int] = {}
        for prefix, numel in sub_kernel.numels.items():
            if prefix_is_reduction(prefix) and not sub_kernel.inside_reduction:
                continue
            size_hints[prefix] = next_power_of_2(V.graph.sizevars.size_hint(numel))

        if sub_kernel.persistent_reduction:
            assert sub_kernel.inside_reduction
            return "persistent_reduction", size_hints
        if sub_kernel.inside_reduction:
            return "reduction", size_hints
        return "pointwise", size_hints

    def select_combo_heuristics(
        self, heuristics_list: List[str], size_hints_list: List[Dict[str, int]]
    ) -> Tuple[str, Dict[str, int], TritonKernel]:
        """Pick the heuristics name, size hints and representative subkernel.

        Args:
            heuristics_list: per-subkernel heuristics names from ``select_heuristics``.
            size_hints_list: per-subkernel size hints, parallel to ``heuristics_list``.

        Without autotuning, "foreach" and the first subkernel are used. Otherwise the
        subkernel with the largest x hint among the "reduction" subkernels (or, when
        there are none, among the "pointwise" ones) is selected. In the pointwise
        case, mixing in persistent reductions yields "pointwise_with_reduction", and
        the x hint is capped at 128 when at least 4 subkernels are not pointwise.

        ``size_hints_list`` is never modified; a capped hint dict is a fresh copy.
        """
        if not self.enable_autotune:
            return "foreach", size_hints_list[0], self.sub_kernels[0]
        if "reduction" in heuristics_list:
            i, _ = max(
                enumerate(size_hints_list),
                key=lambda x: x[1]["x"] if heuristics_list[x[0]] == "reduction" else 0,
            )
            return heuristics_list[i], size_hints_list[i], self.sub_kernels[i]
        elif "pointwise" in heuristics_list:
            i, _ = max(
                enumerate(size_hints_list),
                key=lambda x: x[1]["x"] if heuristics_list[x[0]] == "pointwise" else 0,
            )
            # modify size_hint to avoid oom check fail (may be a false alarm)
            num_pointwise = heuristics_list.count("pointwise")
            num_reduction = heuristics_list.count("reduction")
            num_persistent_reduction = heuristics_list.count("persistent_reduction")
            assert (
                num_reduction == 0
            ), "combining pointwise and reduction are not supported yet."
            heuristics = (
                "pointwise_with_reduction"
                if num_persistent_reduction > 0
                else "pointwise"
            )
            size_hints = size_hints_list[i]
            if len(heuristics_list) - num_pointwise >= 4:
                # Copy rather than mutating the caller's dict in place.
                size_hints = {**size_hints, "x": min(128, size_hints["x"])}
            return heuristics, size_hints, self.sub_kernels[i]
        else:
            return heuristics_list[0], size_hints_list[0], self.sub_kernels[0]

    def get_mutated_args_sub_kernels(self) -> List[str]:
        """Return the sorted, de-duplicated kernel-arg names mutated by any subkernel.

        A mutated buffer contributes its input-buffer name and its output-buffer
        name. It also contributes the inner name of its in-place buffer, but only if
        the buffer was removed neither from the graph nor from the subkernel.
        """
        mutated_names = set()
        for kernel in self.sub_kernels:
            args = kernel.args
            for buf in kernel.mutations:
                if buf in args.input_buffers:
                    mutated_names.add(args.input_buffers[buf])
                live_inplace = (
                    buf in args.inplace_buffers
                    and buf not in V.graph.removed_buffers
                    and buf not in kernel.removed_buffers
                )
                if live_inplace:
                    mutated_names.add(args.inplace_buffers[buf].inner_name)
                if buf in args.output_buffers:
                    mutated_names.add(args.output_buffers[buf])
        return sorted(mutated_names)

    def select_dispatch_strategy(self) -> None:
        """Choose ``self.dispatch_class`` unless one has already been set.

        Sequential dispatch is used when mixed sizes are disabled, when any x numel
        is symbolic (a str, i.e. a dynamic shape), or when no block count can be
        computed (empty list or all x numels 0). Otherwise round robin is used if
        padding every subkernel to the largest x numel wastes at most
        ``1 - BLOCK_UTILIZATION`` of the blocks.
        """
        if self.dispatch_class is not None:
            return
        # mixed_sizes is used for optimize_mask, so it only allows sequential dispatch
        # Not mixed sizes on y dim technically is ok to use round robin as wells.
        if not self.mixed_sizes or any(isinstance(e, str) for e in self.x_numels_list):
            # str in x_numels_list means a dynamic shape
            self.dispatch_class = ComboKernel.SequentialDispatch
            return
        # A negative x_numels_list element means the kernel is not tunable,
        # i.e., no_x_dim = True
        x_numels_list = [abs(cast(int, e)) for e in self.x_numels_list]
        total = max(x_numels_list, default=0) * len(x_numels_list)
        if total == 0:
            # Nothing to balance; avoid max([]) / division by zero.
            self.dispatch_class = ComboKernel.SequentialDispatch
            return
        needed = sum(x_numels_list)
        if needed / total > BLOCK_UTILIZATION:
            # Introduced overhead (masked blocks) is less than 20%
            self.dispatch_class = ComboKernel.RoundRobinDispatch
        else:
            self.dispatch_class = ComboKernel.SequentialDispatch

    def jit_line(
        self,
        heuristics: str,
        size_hints: Dict[str, int],
        selected_kernel: TritonKernel,
        signature: List[Any],
        argdefs: List[str],
        pointwise_with_reduce: bool = False,
    ) -> str:
        """Render the ``@triton_heuristics.*`` and ``@triton.jit`` decorator lines.

        Before rendering, this computes the per-sub-kernel minimum x blocks and
        selects the dispatch strategy.  The index type is ``tl.int32`` only
        when every sub-kernel uses ``tl.int32``; otherwise it is ``tl.int64``.

        ``pointwise_with_reduce`` is accepted but not read by this method.
        """
        all_int32 = all(sub.index_dtype == "tl.int32" for sub in self.sub_kernels)
        size_dtype = "tl.int32" if all_int32 else "tl.int64"
        for idx, sub in enumerate(self.sub_kernels):
            self.min_x_blocks_sub_kernel(sub, idx)
        self.select_dispatch_strategy()

        # Key order matters: the dict's repr is embedded in the generated source.
        triton_meta = {
            "signature": signature_to_meta(
                signature, size_dtype=size_dtype, argdefs=argdefs
            ),
            "device": DeviceProperties.create(V.graph.get_current_device_or_throw()),
            "constants": {},
            "configs": [config_of(signature)],
        }
        mutated_names = self.get_mutated_args_sub_kernels()
        inductor_meta = {
            "kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
            "mutated_arg_names": mutated_names,
            **TritonKernel.inductor_meta_common(),
        }

        if heuristics == "foreach":
            return f"""
                @triton_heuristics.foreach(
                    num_warps={self.num_warps},
                    triton_meta={triton_meta!r},
                    inductor_meta={inductor_meta!r},
                )
                @triton.jit
            """
        elif selected_kernel.inside_reduction:
            reduction_hint = selected_kernel.features.get_reduction_hint()
            return f"""
                @triton_heuristics.{heuristics}(
                    size_hints={size_hints!r},
                    reduction_hint={reduction_hint},
                    filename=__file__,
                    triton_meta={triton_meta!r},
                    inductor_meta={inductor_meta!r}
                )
                @triton.jit
            """
        else:
            # Two size hints -> 2D tile; anything else uses the default tiling.
            tile_hint = (
                "tile_hint=TileHint.SQUARE,"
                if len(size_hints) == 2
                else "tile_hint=TileHint.DEFAULT,"
            )
            return f"""
                @triton_heuristics.{heuristics}(
                    size_hints={size_hints!r}, {tile_hint}
                    filename=__file__,
                    triton_meta={triton_meta!r},
                    inductor_meta={inductor_meta!r}
                )
                @triton.jit
            """

    def codegen_blocks(self, code: IndentedBuffer) -> None:
        """Emit fixed ``<X|Y|R>BLOCK: tl.constexpr`` definitions into ``code``.

        Only XBLOCK, YBLOCK and RBLOCK are allowed in ``self.block_args``.
        When YBLOCK is present, both XBLOCK and YBLOCK use ``block_size_2d``;
        otherwise XBLOCK uses ``block_size_1d``.  RBLOCK, when present, uses
        ``block_size_reduce``.
        """
        supported = ("XBLOCK", "YBLOCK", "RBLOCK")
        for blk in self.block_args:
            assert blk in supported, f"{blk} is not supported without autotuning"
        is_2d = "YBLOCK" in self.block_args
        x_size = self.block_size_2d if is_2d else self.block_size_1d
        code.splice(f"XBLOCK: tl.constexpr = {x_size}")
        if is_2d:
            code.splice(f"YBLOCK: tl.constexpr = {self.block_size_2d}")
        if "RBLOCK" in self.block_args:
            code.splice(f"RBLOCK: tl.constexpr = {self.block_size_reduce}")

    def add_blockd_to_args(self, argdefs: List[str]) -> List[str]:
        """Collect the block-size names the sub-kernels need.

        Reduction trees are skipped when the sub-kernel is not inside a
        reduction or is a persistent reduction.  The x tree is skipped when
        ``no_x_dim`` is set.  The distinct names, in first-seen order, are
        stored in ``self.block_args``.  They are appended to ``argdefs`` as
        ``"<NAME> : tl.constexpr"`` only when autotuning is enabled.
        ``argdefs`` is extended in place and also returned.
        """
        # dict keys give an ordered set of names.
        seen: Dict[str, None] = {}
        for sub in self.sub_kernels:
            # TODO: we assume all sub_kernels have the same block size
            for tree in sub.range_trees:
                skip = (
                    tree.is_reduction
                    and (not sub.inside_reduction or sub.persistent_reduction)
                ) or (tree.prefix == "x" and sub.no_x_dim)
                if not skip:
                    seen.setdefault(f"{tree.prefix.upper()}BLOCK", None)
        if self.enable_autotune:
            argdefs.extend(f"{blk} : tl.constexpr" for blk in seen)
        self.block_args = list(seen)
        return argdefs

    def add_numel_to_args(self, argdefs: List[str], signature: List[Any]) -> List[str]:
        """Add a ``<prefix>numel_<k>`` kernel argument for every non-constant numel.

        Handles each active range tree of sub-kernel ``k`` whose ``numel`` is
        neither an ``Integer`` nor an ``int``.  For such a tree it appends a
        ``SizeArg`` to ``signature``, the argument name to ``argdefs``, and
        the name to ``self.dynamic_shape_args``.  ``argdefs`` is mutated in
        place and returned.
        """
        for kernel_idx, sub in enumerate(self.sub_kernels):
            for tree in sub.active_range_trees():
                if isinstance(tree.numel, (Integer, int)):
                    continue  # static size: no runtime argument needed
                arg_name = f"{tree.prefix}numel_{kernel_idx}"
                signature.append(SizeArg(arg_name, tree.numel))
                argdefs.append(arg_name)
                self.dynamic_shape_args.append(arg_name)
        return argdefs

    def add_numel_to_call_args_and_grid(
        self, name: str, call_args: List[Any], arg_types: List[Any], grid: List[Any]
    ) -> None:
        """Feed each dynamic numel into the launch arguments and the grid.

        Only ``<prefix>numel_<k>`` names recorded in ``self.dynamic_shape_args``
        are handled.  The value used is ``tree.numel`` itself when it is an
        ``Integer``/``Symbol``; otherwise it comes from the wrapper's
        ``generate_numel_expr``.

        * Non-reduction trees: ``grid[dim][k]`` must be the string placeholder
          (optionally prefixed with ``-``).  It is replaced in place by the
          expression, negated when the placeholder was.
        * The expression, and its Python type, are appended to ``call_args`` /
          ``arg_types`` for non-reduction trees.  For reduction trees this
          happens only when the sub-kernel is inside a reduction.
        """
        for kernel_idx, sub in enumerate(self.sub_kernels):
            for dim, tree in enumerate(sub.range_trees):
                numel_name = f"{tree.prefix}numel_{kernel_idx}"
                if numel_name not in self.dynamic_shape_args:
                    continue
                if isinstance(tree.numel, (Integer, Symbol)):
                    expr = tree.numel
                else:
                    expr = V.graph.wrapper_code.generate_numel_expr(
                        name, tree, suffix=str(kernel_idx)
                    )
                if not tree.is_reduction:
                    placeholder = grid[dim][kernel_idx]
                    assert isinstance(
                        placeholder, str
                    ), f"Grid {placeholder} should be a dynamic shape."
                    negated = placeholder[0] == "-"
                    expected = ("-" if negated else "") + numel_name
                    assert (
                        placeholder == expected
                    ), f"numel args mismatch: {placeholder} vs {numel_name}"
                    grid[dim][kernel_idx] = -expr if negated else expr

                if not tree.is_reduction or sub.inside_reduction:
                    call_args.append(expr)
                    arg_types.append(type(expr))

    def add_numel_to_call_args_and_grid_benchmark(
        self, extra_args: List[Any], grid: Union[List[Any], Tuple[Any, ...]]
    ) -> None:
        """Benchmark-harness variant: substitute concrete size hints for dynamic numels.

        For every ``<prefix>numel_<k>`` in ``self.dynamic_shape_args``, the
        value used is ``V.graph.sizevars.size_hint(tree.numel)``.  The
        placeholder in ``grid[dim][k]`` is patched in place (non-reduction
        trees, with a leading ``-`` preserved as negation).  The hint is also
        appended to ``extra_args`` under the same condition as the non-benchmark
        path.
        """
        for kernel_idx, sub in enumerate(self.sub_kernels):
            for dim, tree in enumerate(sub.range_trees):
                numel_name = f"{tree.prefix}numel_{kernel_idx}"
                if numel_name not in self.dynamic_shape_args:
                    continue
                hint = V.graph.sizevars.size_hint(tree.numel)
                if not tree.is_reduction:
                    placeholder = grid[dim][kernel_idx]
                    assert isinstance(
                        placeholder, str
                    ), f"Grid {placeholder} should be a dynamic shape."
                    negated = placeholder[0] == "-"
                    expected = ("-" if negated else "") + numel_name
                    assert (
                        placeholder == expected
                    ), f"grid mismatch: {placeholder} vs {numel_name}"
                    grid[dim][kernel_idx] = -hint if negated else hint
                if not tree.is_reduction or sub.inside_reduction:
                    extra_args.append(hint)

    def codegen_kernel(self, name: Optional[str] = None) -> str:
        """Generate the Triton source code for this combo kernel.

        A single heuristics / size-hints choice is made for the whole combo
        kernel via ``select_combo_heuristics``.  The emitted source contains:
        the common Triton imports, the decorator lines from ``jit_line``, the
        kernel ``def`` line, and a body that routes ``tl.program_id(0)`` to
        each sub-kernel's code.  When ``config.benchmark_combo_kernel`` is set,
        benchmark imports and a benchmark harness are added too.

        Args:
            name: name for the generated ``def``; when falsy,
                ``Placeholder.KERNEL_NAME`` is used instead.

        Returns:
            The generated kernel source as a string.
        """
        # TODO: is it correct to use the first sub kernel's heuristics?
        heuristics_list, size_hints_list = [], []
        for subkernel in self.sub_kernels:
            h, s = self.select_heuristics(subkernel)
            heuristics_list.append(h)
            size_hints_list.append(s)
        heuristics, size_hints, selected_kernel = self.select_combo_heuristics(
            heuristics_list, size_hints_list
        )
        # "pointwise_with_reduction" is rendered with the plain "pointwise"
        # heuristics; the distinction travels as a separate boolean flag.
        pointwise_with_reduction, heuristics = (
            (True, "pointwise")
            if heuristics == "pointwise_with_reduction"
            else (False, heuristics)
        )
        code = IndentedBuffer()

        code.splice(gen_common_triton_imports())
        if config.benchmark_combo_kernel:
            code.splice(self.imports_for_benchmark_kernel())

        # Kernel parameters: regular args, then one numel arg per dynamic
        # dimension, then block-size constexprs (added only when autotuning).
        argdefs, _, signature, _ = self.args.python_argdefs()
        argdefs = self.add_numel_to_args(argdefs, signature)
        argdefs = self.add_blockd_to_args(argdefs)
        # jit_line also selects self.dispatch_class (via select_dispatch_strategy),
        # which the body generation below relies on.
        code.splice(
            self.jit_line(
                heuristics,
                size_hints,
                selected_kernel,
                pointwise_with_reduce=pointwise_with_reduction,
                signature=signature,
                argdefs=argdefs,
            )
        )
        code.writeline(
            f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):"
        )

        with code.indent():
            code.splice("pid = tl.program_id(0)")
            # Without autotuning, block sizes are fixed constexprs declared in
            # the body instead of kernel parameters.
            if not self.enable_autotune:
                self.codegen_blocks(code)

            for num, sub_kernel in enumerate(self.sub_kernels):
                assert self.dispatch_class is not None
                # Emits the pid-range guard for sub-kernel `num`; presumably an
                # `if`/`elif` on `pid`, closed by the `else:` emitted after the loop.
                self.dispatch_class.codegen_pid_range(self, num, code)
                with code.indent():
                    uniquify = self.codegen_static_numels_sub_kernel(
                        code, sub_kernel, num
                    )
                    sub_kernel.codegen_body()
                    # Suffix the names in `uniquify` with `_<num>`, presumably so
                    # sub-kernels sharing one body do not clash.
                    uniquified_body = self.uniquify_block_sizes(
                        sub_kernel.body, num, uniquify
                    )
                    code.splice(uniquified_body)

            # Trailing branch: program ids not routed to any sub-kernel do nothing.
            code.splice("else:")
            with code.indent():
                code.splice("pass")

        if config.benchmark_combo_kernel:
            code.splice(self.codegen_kernel_benchmark(num_gb=0))

        return code.getvalue()

    def codegen_kernel_benchmark(
        self, num_gb: float, grid: Optional[List[Any]] = None
    ) -> IndentedBuffer:
        """Generate a standalone benchmark harness for this combo kernel.

        The returned buffer defines:
        ``get_args()``, which builds inputs from size hints;
        ``call(args)``, which does one launch on the current device's raw stream;
        ``benchmark_all_configs(args)``;
        and an ``if __name__ == '__main__'`` block that times ``call`` and
        prints ms, GB and GB/s.

        Args:
            num_gb: data volume in GB, used only for the printed GB/s figure.
            grid: explicit launch grid text; when ``None`` the grid is derived
                from ``self.dispatch_class`` and wrapped in ``grid_combo_kernels``.

        Returns:
            An ``IndentedBuffer`` holding the harness source.

        Raises:
            KeyError: if a call argument is not a graph buffer, a graph
                constant, a ``SizeArg`` or a ``WorkspaceArg``.
        """
        result = IndentedBuffer()
        # Only call_args and signature are used below; argdefs is unused.
        argdefs, call_args, signature, _ = self.args.python_argdefs()

        result.writelines(["", "", "def get_args():"])
        with result.indent():
            name_cnt = itertools.count()
            var_names = []
            # Assumes call_args and signature are in matching order (both come
            # from python_argdefs) — TODO confirm.
            for arg_name, arg_sig in zip(call_args, signature):
                var_name = f"arg_{next(name_cnt)}"
                buf = V.graph.try_get_buffer(arg_name)
                if buf:
                    # Graph buffer: random tensor with the buffer's hinted size/stride.
                    result.writeline(
                        f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})"  # noqa: B950 line too long
                    )
                elif arg_name in V.graph.constants:
                    # note that random seed is put in V.graph.constants
                    const_tensor = V.graph.constants[arg_name]
                    result.writeline(
                        f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})"  # type: ignore[arg-type]  # noqa: B950 line too long
                    )
                elif isinstance(arg_sig, SizeArg):
                    symval_hint = V.graph.sizevars.size_hint(arg_sig.expr)

                    # Force the seed_offset to be 0 so calls to the same kernel
                    # using different seed offset will have the same benchmark harness.
                    # We can dedup kernel definitions in this case.
                    if "seed_offset" in arg_sig.name:
                        symval_hint = 0
                    result.writeline(f"{var_name} = {symval_hint}")
                elif isinstance(arg_sig, WorkspaceArg):
                    device = V.graph.get_current_device_or_throw()
                    count = V.graph.sizevars.size_hint(arg_sig.count)
                    # for benchmark harness, we ignore arg_sig.zero_mode and always zero it
                    result.writeline(
                        f"{var_name} = torch.zeros({count}, device='{device}', dtype={arg_sig.dtype})"
                    )
                else:
                    raise KeyError(
                        f"Don't find the buffer or const tensor for {arg_name}"
                    )
                var_names.append(var_name)
            result.writeline(f"return {', '.join(var_names)},")

        result.writelines(["\n", "\n", "def call(args):"])
        if grid is None:
            # Derive the grid the same way call_kernel does, but with concrete
            # size hints in place of dynamic numels.
            assert self.dispatch_class is not None
            dynamic_shape = self.dynamic_shape_args != []
            grid_tuple = self.dispatch_class.grid(
                self.grids, self.x_numels_list, dynamic_shape
            )
            extra_args_str = ""
            extra_args: List[Any] = []
            if dynamic_shape:
                # Fills extra_args with numel hints and patches grid_tuple in place.
                self.add_numel_to_call_args_and_grid_benchmark(extra_args, grid_tuple)
                # convert nested list to list of str
                grid_tuple = tuple(
                    "[" + ", ".join(pexpr(item) for item in e) + ",]"
                    for e in grid_tuple
                )
                extra_args_str = ", ".join(map(str, extra_args)) + ", "
                min_blocks = None
            else:
                min_blocks = max(self.min_x_blocks_list) * len(self.sub_kernels)
            grid_str = ", ".join(pexpr(item) for item in grid_tuple)
            grid_extra_kwargs = (
                f"num_kernels={len(self.sub_kernels)}, "
                f"min_blocks={min_blocks}, "
                f"is_sequential={self.dispatch_class is self.SequentialDispatch}"
            )
            grid_str = f"{grid_str}, {grid_extra_kwargs}"
            # The numel hints are passed positionally right after *args.
            grid_arg = f"{extra_args_str}grid=grid_combo_kernels({grid_str})"
        else:
            grid_arg = f"grid={grid}"
        index = V.graph.get_current_device_or_throw().index
        with result.indent():
            result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
            with result.indent():
                result.writeline(
                    V.graph.device_ops.set_device(index)
                )  # no-op to ensure context
                stream_name = f"stream{index}"
                result.writeline(f"{stream_name} = get_raw_stream({index})")
                result.writeline(
                    f"{str(Placeholder.KERNEL_NAME)}.run(*args, {grid_arg}, stream={stream_name})"
                )

        # benchmark all configs
        result.writelines(["\n", "\n", "def benchmark_all_configs(args):"])
        with result.indent():
            result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
            with result.indent():
                result.writeline(
                    V.graph.device_ops.set_device(index)
                )  # no-op to ensure context
                result.writeline(
                    f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args, {grid_arg})"
                )

        # Script entry point: time `call` and report bandwidth.
        result.writelines(["\n", "\n", "if __name__ == '__main__':"])
        with result.indent():
            result.writeline(
                "from torch._inductor.runtime.benchmarking import benchmarker"
            )
            result.writeline("")

            result.writeline("args = get_args()")
            result.writeline(
                "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)"
            )
            result.writeline(f"num_gb = {num_gb}")
            result.writeline("gb_per_s = num_gb / (ms / 1e3)")
            result.writeline(
                'print(f"{ms:.3f}ms    {num_gb:.3f}GB    {gb_per_s:.2f}GB/s")'
            )

        return result

    def imports_for_benchmark_kernel(self) -> str:
        """Return the import lines needed by the generated benchmark harness.

        The template is dedented *before* the device backend's raw-stream
        import line is substituted in.  ``textwrap.dedent`` only strips
        whitespace common to every non-blank line.  Formatting first (as
        before) meant that a multi-line import snippet, whose continuation
        lines have no indent, left every template line indented in the output.
        For single-line snippets the result is unchanged.
        """
        template = textwrap.dedent(
            """
            from torch._dynamo.testing import rand_strided
            {}
            import torch
            from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels
        """
        )
        return template.format(
            V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
        )

    def uniquify_block_sizes(
        self, code: IndentedBuffer, num_kernel: int, uniquify: List[str]
    ) -> IndentedBuffer:
        """Append ``_<num_kernel>`` to every occurrence of the names in ``uniquify``.

        Plain-string lines and ``DeferredLine`` lines are rewritten by
        substring replacement.  The names to replace are decided from the
        original line text, before any replacement is made.  Lines with no
        match, and any other line objects, are copied as-is.  When
        ``uniquify`` is empty, ``code`` itself is returned unchanged;
        otherwise a new buffer is returned.
        """
        if not uniquify:
            return code

        def matching(text: str) -> List[str]:
            # Names present in the untouched line text.
            return [token for token in uniquify if token in text]

        def rename(text: str, tokens: List[str]) -> str:
            for token in tokens:
                text = text.replace(token, f"{token}_{num_kernel}")
            return text

        renamed = IndentedBuffer(initial_indent=code._indent)
        for entry in code._lines:
            if isinstance(entry, str):
                hits = matching(entry)
                renamed.writeline(rename(entry, hits) if hits else entry)
            elif isinstance(entry, DeferredLine) and (hits := matching(entry.line)):
                renamed.writeline(DeferredLine(entry.name, rename(entry.line, hits)))
            else:
                renamed.writeline(entry)
        return renamed

    def call_kernel(self, code: IndentedBuffer, name: str) -> None:
        """Emit the wrapper-code launch of this combo kernel.

        * Dynamic shapes: the dynamic numels are first appended to the call
          arguments and substituted into the grid.  The launch then uses a
          ``grid_combo_kernels`` grid built by the wrapper.
        * Static shapes without autotuning: the grid is resolved immediately
          with the default block sizes (``grid_no_autotune``), and the launch
          uses an empty ``grid_fn``.
        * Otherwise (autotuning enabled): the wrapper builds a
          ``grid_combo_kernels`` grid.

        ``default_meta`` is supplied only when autotuning is disabled.  The
        ``code`` argument is not used by this method.
        """
        _, call_args, _, arg_types = self.args.python_argdefs()

        wrapper = V.graph.wrapper_code
        assert self.dispatch_class is not None
        dynamic_shape = self.dynamic_shape_args != []
        grid = list(
            self.dispatch_class.grid(self.grids, self.x_numels_list, dynamic_shape)
        )
        num_kernels = len(self.sub_kernels)
        # A fixed minimum block count is only known when all numels are static.
        min_blocks = (
            None if dynamic_shape else max(self.min_x_blocks_list) * num_kernels
        )
        is_sequential = self.dispatch_class is self.SequentialDispatch

        if dynamic_shape:
            # Appends the numels to call_args/arg_types; patches `grid` in place.
            self.add_numel_to_call_args_and_grid(name, call_args, arg_types, grid)
        elif not self.enable_autotune:
            # Static shapes and fixed block sizes: resolve the grid right now.
            launch_grid = self.grid_no_autotune(
                grid, num_kernels, cast(int, min_blocks), is_sequential
            )
            V.graph.wrapper_code.generate_kernel_call(
                name,
                call_args,
                grid=launch_grid,
                arg_types=arg_types,
                grid_fn="",
            )
            return

        def default_meta():
            # Fixed block sizes are only needed when autotuning cannot pick them.
            return None if self.enable_autotune else self.get_default_meta()

        deferred_grid = wrapper.generate_default_grid(
            name,
            list(grid),
            grid_callable=grid_combo_kernels,
            num_kernels=num_kernels,
            min_blocks=min_blocks,
            is_sequential=is_sequential,
            default_meta=default_meta(),
        )
        extra_kwargs = (
            f"num_kernels={num_kernels}, "
            f"min_blocks={min_blocks}, "
            f"is_sequential={is_sequential}, "
        )
        wrapper.generate_kernel_call(
            name,
            call_args,
            deferred_grid,
            V.graph.get_current_device_or_throw().index,
            gpu=True,
            triton=True,
            arg_types=arg_types,
            grid_fn="grid_combo_kernels",
            grid_extra_kwargs=extra_kwargs + f"default_meta={default_meta()}",
        )

    def grid_no_autotune(
        self,
        grid: Union[Tuple[Any], List[Any]],
        num_kernels: int,
        min_blocks: int,
        is_sequential: bool,
    ) -> List[int]:
        """Resolve the launch grid now, using the fixed default block sizes.

        ``grid_combo_kernels`` builds a grid function from the per-dimension
        ``grid`` entries.  That function is evaluated immediately with
        ``get_default_meta()`` as its meta argument.
        """
        default_meta = self.get_default_meta()
        return grid_combo_kernels(
            *grid,
            num_kernels=num_kernels,
            min_blocks=min_blocks,
            is_sequential=is_sequential,
        )(default_meta)

    def get_default_meta(self) -> Dict[str, int]:
        """Return the fixed block sizes used when autotuning is disabled.

        With a YBLOCK in ``self.block_args``, both XBLOCK and YBLOCK map to
        ``block_size_2d``.  Otherwise only XBLOCK is set, to ``block_size_1d``.
        """
        if "YBLOCK" not in self.block_args:
            return {"XBLOCK": self.block_size_1d}
        size_2d = self.block_size_2d
        return {"XBLOCK": size_2d, "YBLOCK": size_2d}